Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f2cc8405
"...composable_kernel_rocm.git" did not exist on "11e4082dd8459f3a0e69f7d164ef64eb7ebfa7fa"
Commit
f2cc8405
authored
Jun 09, 2022
by
Jing Zhang
Browse files
fixed comments
parent
35b07efb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
95 additions
and
93 deletions
+95
-93
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
+51
-55
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+5
-5
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+6
-6
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
...peration/gpu/device/device_grouped_gemm_transpose_xdl.hpp
+21
-15
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+6
-6
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_transpose.hpp
...ference_tensor_operation/cpu/reference_gemm_transpose.hpp
+2
-2
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+4
-4
No files found.
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
View file @
f2cc8405
...
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit
(
0
);
}
int
group_count
=
4
;
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmTransposeDesc
>
gemm_descs
;
...
...
@@ -89,66 +89,62 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
B
=
1
6
;
int
S
=
64
;
int
NumHead
=
1
6
;
int
HeadDim
=
64
;
const
int
M0
=
rand
()
%
4
+
1
;
const
int
M1
=
256
;
const
int
N0
=
rand
()
%
4
+
1
;
const
int
N1
=
256
;
int
M0
=
B
;
int
M1
=
S
;
int
N0
=
NumHead
;
int
N1
=
HeadDim
;
const
int
M
=
M0
*
N1
;
const
int
N
=
N0
*
N1
;
int
M
=
M0
*
N1
;
int
N
=
N0
*
N1
;
int
K
=
NumHead
*
HeadDim
;
const
int
K
=
128
*
(
rand
()
%
4
+
1
);
int
S
trideA
=
K
;
int
S
trideB
=
K
;
const
int
s
tride
_
A
=
K
;
const
int
s
tride
_
B
=
K
;
if
(
i
%
2
==
0
)
{
int
S
trideM0
=
S
*
NumHead
*
HeadDim
;
int
S
trideM1
=
1
;
int
S
trideN0
=
S
*
HeadDim
;
int
S
trideN1
=
S
;
// output layout [M0, N0, M1, N1]
const
int
s
tride
_
M0
=
N1
*
M1
*
N0
;
const
int
s
tride
_
M1
=
N
1
;
const
int
s
tride
_
N0
=
N1
*
M1
;
const
int
s
tride
_
N1
=
1
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
S
trideA
,
S
trideB
,
s
tride
_
A
,
s
tride
_
B
,
M0
,
M1
,
N0
,
N1
,
S
trideM0
,
S
trideM1
,
S
trideN0
,
S
trideN1
});
s
tride
_
M0
,
s
tride
_
M1
,
s
tride
_
N0
,
s
tride
_
N1
});
}
else
{
int
S
trideM0
=
S
*
N
umHead
*
HeadDim
;
int
S
trideM1
=
HeadDim
;
int
S
trideN0
=
S
*
HeadDim
;
int
S
trideN1
=
1
;
// output layout [M0, N0, N1, M1]
int
s
tride
_
M0
=
N1
*
N
1
*
N0
;
int
s
tride
_
M1
=
1
;
int
s
tride
_
N0
=
M1
*
N1
;
int
s
tride
_
N1
=
M
1
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
S
trideA
,
S
trideB
,
s
tride
_
A
,
s
tride
_
B
,
M0
,
M1
,
N0
,
N1
,
S
trideM0
,
S
trideM1
,
S
trideN0
,
S
trideN1
});
s
tride
_
M0
,
s
tride
_
M1
,
s
tride
_
N0
,
s
tride
_
N1
});
}
}
...
...
@@ -202,33 +198,33 @@ int main(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
S
tride
A
,
ALayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
B
,
BLayout
{})));
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
,
gemm_descs
[
i
].
M1
,
gemm_descs
[
i
].
N0
,
gemm_descs
[
i
].
N1
,
gemm_descs
[
i
].
S
trideM0
,
gemm_descs
[
i
].
S
trideM1
,
gemm_descs
[
i
].
S
trideN0
,
gemm_descs
[
i
].
S
trideN1
)));
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
_
,
gemm_descs
[
i
].
M1
_
,
gemm_descs
[
i
].
N0
_
,
gemm_descs
[
i
].
N1
_
,
gemm_descs
[
i
].
s
tride
_
M0
_
,
gemm_descs
[
i
].
s
tride
_
M1
_
,
gemm_descs
[
i
].
s
tride
_
N0
_
,
gemm_descs
[
i
].
s
tride
_
N1
_
)));
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
,
gemm_descs
[
i
].
M1
,
gemm_descs
[
i
].
N0
,
gemm_descs
[
i
].
N1
,
gemm_descs
[
i
].
S
trideM0
,
gemm_descs
[
i
].
S
trideM1
,
gemm_descs
[
i
].
S
trideN0
,
gemm_descs
[
i
].
S
trideN1
)));
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
_
,
gemm_descs
[
i
].
M1
_
,
gemm_descs
[
i
].
N0
_
,
gemm_descs
[
i
].
N1
_
,
gemm_descs
[
i
].
s
tride
_
M0
_
,
gemm_descs
[
i
].
s
tride
_
M1
_
,
gemm_descs
[
i
].
s
tride
_
N0
_
,
gemm_descs
[
i
].
s
tride
_
N1
_
)));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M
*
gemm_descs
[
i
].
K
*
gemm_descs
[
i
].
N
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M
_
*
gemm_descs
[
i
].
K
_
*
gemm_descs
[
i
].
N
_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
f2cc8405
...
...
@@ -133,19 +133,19 @@ int main(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
S
tride
A
,
ALayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
B
,
BLayout
{})));
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
C
,
CLayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_C_
,
CLayout
{})));
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
C
,
CLayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_C_
,
CLayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M
*
gemm_descs
[
i
].
K
*
gemm_descs
[
i
].
N
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M
_
*
gemm_descs
[
i
].
K
_
*
gemm_descs
[
i
].
N
_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
f2cc8405
...
...
@@ -10,17 +10,17 @@ namespace device {
struct
GemmDesc
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
S
tride
A
,
S
tride
B
,
S
tride
C
;
ck
::
index_t
M
_
,
N
_
,
K
_
;
ck
::
index_t
s
tride
_A_
,
s
tride
_B_
,
s
tride
_C_
;
};
struct
GemmTransposeDesc
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
S
tride
A
,
S
tride
B
;
ck
::
index_t
M
_
,
N
_
,
K
_
;
ck
::
index_t
s
tride
_A_
,
s
tride
_B_
;
ck
::
index_t
M0
,
M1
,
N0
,
N1
;
ck
::
index_t
S
trideM0
,
S
trideM1
,
S
trideN0
,
S
trideN1
;
ck
::
index_t
M0
_
,
M1
_
,
N0
_
,
N1
_
;
ck
::
index_t
s
tride
_
M0
_
,
s
tride
_
M1
_
,
s
tride
_
N0
_
,
s
tride
_
N1
_
;
};
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
View file @
f2cc8405
#ifndef DEVICE_GROUPED_GEMM_XDL_HPP
#define DEVICE_GROUPED_GEMM_XDL_HPP
#ifndef DEVICE_GROUPED_GEMM_
TRANSPOSE_
XDL_HPP
#define DEVICE_GROUPED_GEMM_
TRANSPOSE_
XDL_HPP
#include <iostream>
#include <sstream>
...
...
@@ -389,12 +389,18 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen
for
(
std
::
size_t
i
=
0
;
i
<
gemm_transpose_desc
.
size
();
i
++
)
{
const
index_t
M
=
gemm_transpose_desc
[
i
].
M
;
const
index_t
N
=
gemm_transpose_desc
[
i
].
N
;
const
index_t
K
=
gemm_transpose_desc
[
i
].
K
;
const
index_t
M
=
gemm_transpose_desc
[
i
].
M
_
;
const
index_t
N
=
gemm_transpose_desc
[
i
].
N
_
;
const
index_t
K
=
gemm_transpose_desc
[
i
].
K
_
;
const
index_t
StrideA
=
gemm_transpose_desc
[
i
].
StrideA
;
const
index_t
StrideB
=
gemm_transpose_desc
[
i
].
StrideB
;
const
index_t
StrideA
=
gemm_transpose_desc
[
i
].
stride_A_
;
const
index_t
StrideB
=
gemm_transpose_desc
[
i
].
stride_B_
;
if
(
!
(
M
==
gemm_transpose_desc
[
i
].
M0_
*
gemm_transpose_desc
[
i
].
M1_
&&
N
==
gemm_transpose_desc
[
i
].
N0_
*
gemm_transpose_desc
[
i
].
N1_
))
{
throw
std
::
runtime_error
(
"wrong! M != M0 * M1 or N != N0 * N1"
);
}
const
auto
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmTransposeXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
...
...
@@ -402,14 +408,14 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen
DeviceGroupedGemmTransposeXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
const
auto
c_grid_desc_m_n_
=
DeviceGroupedGemmTransposeXdl
::
MakeCGridDescriptor_M_N
(
gemm_transpose_desc
[
i
].
M0
,
gemm_transpose_desc
[
i
].
M1
,
gemm_transpose_desc
[
i
].
N0
,
gemm_transpose_desc
[
i
].
N1
,
gemm_transpose_desc
[
i
].
S
trideM0
,
gemm_transpose_desc
[
i
].
S
trideM1
,
gemm_transpose_desc
[
i
].
S
trideN0
,
gemm_transpose_desc
[
i
].
S
trideN1
);
gemm_transpose_desc
[
i
].
M0
_
,
gemm_transpose_desc
[
i
].
M1
_
,
gemm_transpose_desc
[
i
].
N0
_
,
gemm_transpose_desc
[
i
].
N1
_
,
gemm_transpose_desc
[
i
].
s
tride
_
M0
_
,
gemm_transpose_desc
[
i
].
s
tride
_
M1
_
,
gemm_transpose_desc
[
i
].
s
tride
_
N0
_
,
gemm_transpose_desc
[
i
].
s
tride
_
N1
_
);
const
index_t
grid_size_grp
=
GroupedGemmBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
,
0
)
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
f2cc8405
...
...
@@ -377,13 +377,13 @@ struct DeviceGroupedGemmXdl
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
const
index_t
M
=
gemm_descs
[
i
].
M
;
const
index_t
N
=
gemm_descs
[
i
].
N
;
const
index_t
K
=
gemm_descs
[
i
].
K
;
const
index_t
M
=
gemm_descs
[
i
].
M
_
;
const
index_t
N
=
gemm_descs
[
i
].
N
_
;
const
index_t
K
=
gemm_descs
[
i
].
K
_
;
const
index_t
StrideA
=
gemm_descs
[
i
].
S
tride
A
;
const
index_t
StrideB
=
gemm_descs
[
i
].
S
tride
B
;
const
index_t
StrideC
=
gemm_descs
[
i
].
S
tride
C
;
const
index_t
StrideA
=
gemm_descs
[
i
].
s
tride
_A_
;
const
index_t
StrideB
=
gemm_descs
[
i
].
s
tride
_B_
;
const
index_t
StrideC
=
gemm_descs
[
i
].
s
tride
_C_
;
const
auto
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_transpose.hpp
View file @
f2cc8405
...
...
@@ -63,8 +63,8 @@ struct ReferenceGemmTranspose : public device::BaseOperator
float
v_a
;
float
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cas
t
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b
,
static_cas
t
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
arg
.
a_element_op_
(
v_a
,
ck
::
type_conver
t
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b
,
ck
::
type_conver
t
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
v_acc
+=
v_a
*
v_b
;
}
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
f2cc8405
...
...
@@ -107,13 +107,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
emplace_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
S
tride
A
,
ALayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
emplace_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
B
,
BLayout
{})));
gemm_descs
[
i
].
K
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
C
,
CLayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_C_
,
CLayout
{})));
c_device_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
S
tride
C
,
CLayout
{})));
gemm_descs
[
i
].
M
_
,
gemm_descs
[
i
].
N
_
,
gemm_descs
[
i
].
s
tride
_C_
,
CLayout
{})));
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment