Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35f1a04e
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "47545935c167c28f66896768481e7d25a870eaff"
Commit
35f1a04e
authored
Dec 14, 2022
by
Rosty Geyyer
Browse files
Add padding device_gemm_fastgelu_xdl_c_shuffle instances
parent
6a617c09
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
108 additions
and
4 deletions
+108
-4
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+27
-1
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+27
-1
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
+27
-1
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
..._fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
+27
-1
No files found.
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
View file @
35f1a04e
...
@@ -15,7 +15,8 @@ namespace instance {
...
@@ -15,7 +15,8 @@ namespace instance {
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// e = elementwise((a * b))
// e = elementwise((a * b))
// outout: e[m, n]
// outout: e[m, n]
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::t
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::t
// clang-format on
// clang-format on
>
;
>
;
// irregular tile size
using
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
#endif
// clang-format on
>
;
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Row
,
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
{});
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
View file @
35f1a04e
...
@@ -15,7 +15,8 @@ namespace instance {
...
@@ -15,7 +15,8 @@ namespace instance {
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// e = elementwise((a * b))
// e = elementwise((a * b))
// outout: e[m, n]
// outout: e[m, n]
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::t
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::t
// clang-format on
// clang-format on
>
;
>
;
// irregular tile size
using
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Col
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
#endif
// clang-format on
>
;
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Col
,
Col
,
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
{});
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
35f1a04e
...
@@ -15,7 +15,8 @@ namespace instance {
...
@@ -15,7 +15,8 @@ namespace instance {
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// e = elementwise((a * b))
// e = elementwise((a * b))
// outout: e[m, n]
// outout: e[m, n]
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::t
...
@@ -86,6 +87,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::t
// clang-format on
// clang-format on
>
;
>
;
// irregular tile size
using
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
#endif
// clang-format on
>
;
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row
,
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
...
@@ -101,6 +124,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
{});
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
35f1a04e
...
@@ -15,7 +15,8 @@ namespace instance {
...
@@ -15,7 +15,8 @@ namespace instance {
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// e = elementwise((a * b))
// e = elementwise((a * b))
// outout: e[m, n]
// outout: e[m, n]
...
@@ -77,6 +78,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::t
...
@@ -77,6 +78,28 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::t
// clang-format on
// clang-format on
>
;
>
;
// irregular tile size
using
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F32
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
#endif
// clang-format on
>
;
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Col
,
...
@@ -92,6 +115,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
...
@@ -92,6 +115,9 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
{});
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment