Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
89c46e79
Commit
89c46e79
authored
Jul 11, 2022
by
Jing Zhang
Browse files
fixed test_grouped_gemm
parent
bd0f0686
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+11
-11
No files found.
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
89c46e79
...
...
@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int
group_count
=
rand
()
%
10
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
...
...
@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int
BStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
?
N
:
K
;
int
CStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
?
N
:
M
;
gemm_
shape
s
.
push_back
({
M
,
N
,
K
,
AStride
,
BStride
,
CStride
});
gemm_
desc
s
.
push_back
({
M
,
N
,
K
,
AStride
,
BStride
,
CStride
});
}
auto
f_host_tensor_descriptor
=
...
...
@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
a_tensors
.
emplace_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
S
tride
A
,
ALayout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
emplace_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
B
,
BLayout
{})));
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
CLayout
{})));
c_device_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
CLayout
{})));
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()));
...
...
@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto
invoker_ptr
=
groupedGemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
groupedGemmPtr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
a_element_op
,
b_element_op
,
c_element_op
);
p_a
,
p_b
,
p_c
,
gemm_
desc
s
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
groupedGemmPtr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
...
...
@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
invoker_ptr
->
Run
(
argument_ptr
.
get
());
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment