Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
fe6ce55c
Unverified
Commit
fe6ce55c
authored
Mar 28, 2022
by
zjing14
Committed by
GitHub
Mar 28, 2022
Browse files
Grouped gemm test fix (#150)
* fixed test: return res; rand gemm shapes * fixed return
parent
313bbea5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
6 deletions
+13
-6
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+13
-6
No files found.
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
fe6ce55c
...
...
@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
bool
TestGroupedGemm
(
DeviceGroupedGemmPtr_
&
groupedGemmPtr
)
{
int
group_count
=
4
;
int
group_count
=
rand
()
%
10
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
...
...
@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
256
+
256
*
i
;
int
N
=
128
+
128
*
i
;
int
K
=
128
+
64
*
i
;
int
M
=
256
+
256
*
(
rand
()
%
10
)
;
int
N
=
256
+
256
*
(
rand
()
%
10
)
;
int
K
=
128
+
128
*
(
rand
()
%
10
)
;
int
AStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
?
K
:
M
;
int
BStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
?
N
:
K
;
...
...
@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
c_device_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_shapes
[
i
].
M
,
gemm_shapes
[
i
].
N
,
gemm_shapes
[
i
].
StrideC
,
CLayout
{})));
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_
3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_
3
<
BDataType
>
{
-
0.
5
,
0.
5
});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_
2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_
2
<
BDataType
>
{
-
5
,
5
});
}
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
...
...
@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_element_op
,
c_element_op
);
if
(
!
groupedGemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
return
false
;
}
ref_invoker
.
Run
(
ref_argument
);
bool
res
=
check_err
(
c_device_tensors
[
i
],
c_host_tensors
[
i
]);
...
...
@@ -210,4 +215,6 @@ int main()
}
std
::
cout
<<
"TestGroupedGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment