Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f1cdecfb
Commit
f1cdecfb
authored
May 31, 2022
by
Jing Zhang
Browse files
fix
parent
426abafe
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
120 additions
and
67 deletions
+120
-67
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
+38
-9
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+17
-16
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
...peration/gpu/device/device_grouped_gemm_transpose_xdl.hpp
+65
-42
No files found.
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
View file @
f1cdecfb
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit
(
0
);
exit
(
0
);
}
}
int
group_count
=
rand
()
%
16
+
1
;
int
group_count
=
4
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmTransposeDesc
>
gemm_descs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmTransposeDesc
>
gemm_descs
;
...
@@ -89,11 +89,13 @@ int main(int argc, char* argv[])
...
@@ -89,11 +89,13 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
{
int
M
=
1024
;
int
B
=
16
;
int
N
=
1024
;
int
S
=
64
;
int
K
=
1024
;
int
nH
=
16
;
int
hD
=
64
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
gemm_descs
.
push_back
(
{
B
*
S
,
nH
*
hD
,
nH
*
hD
,
nH
*
hD
,
nH
*
hD
,
B
,
S
,
nH
,
hD
,
S
*
nH
*
hD
,
S
*
hD
,
hD
,
1
});
}
}
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
@@ -110,6 +112,19 @@ int main(int argc, char* argv[])
...
@@ -110,6 +112,19 @@ int main(int argc, char* argv[])
}
}
};
};
auto
f_host_c_tensor_descriptor
=
[](
std
::
size_t
M0
,
std
::
size_t
M1
,
std
::
size_t
N0
,
std
::
size_t
N1
,
std
::
size_t
StrideM0
,
std
::
size_t
StrideM1
,
std
::
size_t
StrideN0
,
std
::
size_t
StrideN1
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
M0
,
M1
,
N0
,
N1
}),
std
::
vector
<
std
::
size_t
>
({
StrideM0
,
StrideM1
,
StrideN0
,
StrideN1
}));
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
...
@@ -136,10 +151,24 @@ int main(int argc, char* argv[])
...
@@ -136,10 +151,24 @@ int main(int argc, char* argv[])
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
StrideA
,
ALayout
{})));
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
StrideA
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideB
,
BLayout
{})));
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideB
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
c_host_tensors
.
push_back
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideC
,
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
,
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M1
,
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideC
,
CLayout
{})));
gemm_descs
[
i
].
N0
,
gemm_descs
[
i
].
N1
,
gemm_descs
[
i
].
StrideM0
,
gemm_descs
[
i
].
StrideM1
,
gemm_descs
[
i
].
StrideN0
,
gemm_descs
[
i
].
StrideN1
)));
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_c_tensor_descriptor
(
gemm_descs
[
i
].
M0
,
gemm_descs
[
i
].
M1
,
gemm_descs
[
i
].
N0
,
gemm_descs
[
i
].
N1
,
gemm_descs
[
i
].
StrideM0
,
gemm_descs
[
i
].
StrideM1
,
gemm_descs
[
i
].
StrideN0
,
gemm_descs
[
i
].
StrideN1
)));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
f1cdecfb
...
@@ -16,11 +16,11 @@ struct GemmDesc
...
@@ -16,11 +16,11 @@ struct GemmDesc
struct
GemmTransposeDesc
struct
GemmTransposeDesc
{
{
ck
::
index_t
M
,
N
,
K
;
;
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
ck
::
index_t
StrideA
,
StrideB
;
ck
::
index_t
B
,
S
,
NumHead
,
HeadDim
;
ck
::
index_t
M0
,
M1
,
N0
,
N1
;
std
::
vector
<
ck
::
index_t
>
transpose
;
ck
::
index_t
StrideM0
,
StrideM1
,
StrideN0
,
StrideN1
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
@@ -51,14 +51,15 @@ template <typename AElementwiseOperation,
...
@@ -51,14 +51,15 @@ template <typename AElementwiseOperation,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmTranspose
:
public
BaseOperator
struct
DeviceGroupedGemmTranspose
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
virtual
std
::
unique_ptr
<
BaseArgument
>
std
::
vector
<
const
void
*>&
p_b
,
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
GemmTransposeDesc
>&
gemm_transpose_desc
,
std
::
vector
<
void
*>&
p_c
,
AElementwiseOperation
a_element_op
,
std
::
vector
<
GemmTransposeDesc
>&
gemm_transpose_desc
,
BElementwiseOperation
b_element_op
,
AElementwiseOperation
a_element_op
,
CElementwiseOperation
c_element_op
,
BElementwiseOperation
b_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
@@ -66,10 +67,10 @@ struct DeviceGroupedGemmTranspose : public BaseOperator
...
@@ -66,10 +67,10 @@ struct DeviceGroupedGemmTranspose : public BaseOperator
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
using
DeviceGroupedGemmTransposePtr
=
std
::
unique_ptr
<
using
DeviceGroupedGemmTransposePtr
=
DeviceGroupedGemmTranspose
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
std
::
unique_ptr
<
DeviceGroupedGemmTranspose
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
View file @
f1cdecfb
...
@@ -29,11 +29,12 @@ __global__ void
...
@@ -29,11 +29,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_gemm_transpose_xdlops_v2r3
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
kernel_grouped_gemm_transpose_xdlops_v2r3
(
const
index_t
group_count
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
AElementwiseOperation
a_element_op
,
const
index_t
group_count
,
const
BElementwiseOperation
b_element_op
,
const
AElementwiseOperation
a_element_op
,
const
CElementwiseOperation
c_element_op
)
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -111,8 +112,9 @@ template <typename ADataType,
...
@@ -111,8 +112,9 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
index_t
MaxGroupCount
=
10
>
ck
::
index_t
MaxGroupCount
=
10
>
struct
DeviceGroupedGemmTransposeXdl
struct
DeviceGroupedGemmTransposeXdl
:
public
DeviceGroupedGemmTranspose
<
AElementwiseOperation
,
:
public
DeviceGroupedGemmTranspose
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
BElementwiseOperation
,
CElementwiseOperation
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -198,19 +200,29 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -198,19 +200,29 @@ struct DeviceGroupedGemmTransposeXdl
}
}
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
M0
,
index_t
M1
,
index_t
N0
,
index_t
N1
,
index_t
StrideM0
,
index_t
StrideM1
,
index_t
StrideN0
,
index_t
StrideN1
)
{
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
const
expr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
const
auto
c_grid_desc_m0_m1_n0_n1
=
make_naive_tensor_descriptor
(
{
make_tuple
(
M0
,
M1
,
N0
,
N1
),
make_tuple
(
StrideM0
,
StrideM1
,
StrideN0
,
StrideN1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
return
transform_tensor_descriptor
(
c_grid_desc_m0_m1_n0_n1
,
else
if
constexpr
(
is_same
<
te
nsor
_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
make_tuple
(
make_merge_tra
ns
f
or
m
(
make_tuple
(
M0
,
M1
)),
{
make_merge_transform
(
make_tuple
(
N0
,
N1
))),
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}),
}
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
}();
const
index_t
M
=
M0
*
M1
;
const
index_t
N
=
N0
*
N1
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
...
@@ -235,7 +247,7 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -235,7 +247,7 @@ struct DeviceGroupedGemmTransposeXdl
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
...
@@ -384,14 +396,21 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -384,14 +396,21 @@ struct DeviceGroupedGemmTransposeXdl
const
index_t
StrideA
=
gemm_transpose_desc
[
i
].
StrideA
;
const
index_t
StrideA
=
gemm_transpose_desc
[
i
].
StrideA
;
const
index_t
StrideB
=
gemm_transpose_desc
[
i
].
StrideB
;
const
index_t
StrideB
=
gemm_transpose_desc
[
i
].
StrideB
;
const
index_t
StrideC
=
gemm_transpose_desc
[
i
].
StrideC
;
const
auto
a_grid_desc_k0_m_k1_
=
const
auto
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmTransposeXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
DeviceGroupedGemmTransposeXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
const
auto
b_grid_desc_k0_n_k1_
=
const
auto
b_grid_desc_k0_n_k1_
=
DeviceGroupedGemmTransposeXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
DeviceGroupedGemmTransposeXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
const
auto
c_grid_desc_m_n_
=
const
auto
c_grid_desc_m_n_
=
DeviceGroupedGemmTransposeXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
DeviceGroupedGemmTransposeXdl
::
MakeCGridDescriptor_M_N
(
gemm_transpose_desc
[
i
].
M0
,
gemm_transpose_desc
[
i
].
M1
,
gemm_transpose_desc
[
i
].
N0
,
gemm_transpose_desc
[
i
].
N1
,
gemm_transpose_desc
[
i
].
StrideM0
,
gemm_transpose_desc
[
i
].
StrideM1
,
gemm_transpose_desc
[
i
].
StrideN0
,
gemm_transpose_desc
[
i
].
StrideN1
);
const
index_t
grid_size_grp
=
const
index_t
grid_size_grp
=
typename
GroupedGemmBlock2CTileMap
::
UnderlyingBlock2CTileMap
(
typename
GroupedGemmBlock2CTileMap
::
UnderlyingBlock2CTileMap
(
...
@@ -501,13 +520,14 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -501,13 +520,14 @@ struct DeviceGroupedGemmTransposeXdl
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_gemm_transpose_xdlops_v2r3
<
GridwiseGemm
,
kernel_grouped_gemm_transpose_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B
CDataType
,
// datatype
GemmDescKernelArg
,
CDataType
,
AElementwiseOperation
,
GemmDescKernelArg
,
BElementwiseOperation
,
AElementwiseOperation
,
CElementwiseOperation
,
BElementwiseOperation
,
true
>
;
CElementwiseOperation
,
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
@@ -525,13 +545,14 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -525,13 +545,14 @@ struct DeviceGroupedGemmTransposeXdl
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_gemm_transpose_xdlops_v2r3
<
GridwiseGemm
,
kernel_grouped_gemm_transpose_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B
CDataType
,
// datatype
GemmDescKernelArg
,
CDataType
,
AElementwiseOperation
,
GemmDescKernelArg
,
BElementwiseOperation
,
AElementwiseOperation
,
CElementwiseOperation
,
BElementwiseOperation
,
false
>
;
CElementwiseOperation
,
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
@@ -585,20 +606,22 @@ struct DeviceGroupedGemmTransposeXdl
...
@@ -585,20 +606,22 @@ struct DeviceGroupedGemmTransposeXdl
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_transpose_desc
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_transpose_desc
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
unique_ptr
<
BaseArgument
>
std
::
vector
<
const
void
*>&
p_b
,
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
GemmTransposeDesc
>&
gemm_transpose_desc
,
std
::
vector
<
void
*>&
p_c
,
AElementwiseOperation
a_element_op
,
std
::
vector
<
GemmTransposeDesc
>&
gemm_transpose_desc
,
BElementwiseOperation
b_element_op
,
AElementwiseOperation
a_element_op
,
CElementwiseOperation
c_element_op
,
BElementwiseOperation
b_element_op
,
index_t
/* KBatch */
=
1
)
override
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_c
,
gemm_transpose_desc
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
p_a
,
p_b
,
p_c
,
gemm_transpose_desc
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment