Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b89a88b5
Commit
b89a88b5
authored
Sep 19, 2022
by
Adam Osewski
Browse files
Merge branch 'develop' into wavelet_model
parents
41d5fca7
43c898f6
Changes
261
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
794 additions
and
117 deletions
+794
-117
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+12
-7
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+18
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+3
-0
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt
...nstance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt
+4
-0
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+80
-0
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
+81
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
...r_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+16
-3
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
+80
-0
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+15
-2
library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp
...on_instance/gpu/elementwise/device_normalize_instance.cpp
+8
-10
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
.../device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
+13
-13
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
.../device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
+13
-13
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
.../device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
+14
-19
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
.../device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
+18
-18
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp
...nce/gpu/normalization/device_softmax_f16_f16_instance.cpp
+25
-17
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp
...nce/gpu/normalization/device_softmax_f32_f32_instance.cpp
+23
-15
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+4
-0
profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp
...r/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp
+360
-0
profiler/include/profile_batched_gemm_gemm_impl.hpp
profiler/include/profile_batched_gemm_gemm_impl.hpp
+6
-0
No files found.
library/include/ck/library/utility/host_tensor.hpp
View file @
b89a88b5
...
...
@@ -254,23 +254,28 @@ struct Tensor
Tensor
(
const
HostTensorDescriptor
&
desc
)
:
mDesc
(
desc
),
mData
(
mDesc
.
GetElementSpaceSize
())
{}
template
<
typename
OutT
>
Tensor
<
OutT
>
CopyAsType
()
Tensor
<
OutT
>
CopyAsType
()
const
{
Tensor
<
OutT
>
ret
(
mDesc
);
for
(
size_t
i
=
0
;
i
<
mData
.
size
();
i
++
)
{
ret
.
mData
[
i
]
=
static_cas
t
<
OutT
>
(
mData
[
i
]);
ret
.
mData
[
i
]
=
ck
::
type_conver
t
<
OutT
>
(
mData
[
i
]);
}
return
ret
;
}
Tensor
(
const
Tensor
&
other
)
:
mDesc
(
other
.
mDesc
),
mData
(
other
.
mData
)
{}
Tensor
()
=
delete
;
Tensor
(
const
Tensor
&
)
=
default
;
Tensor
(
Tensor
&&
)
=
default
;
Tensor
&
operator
=
(
const
Tensor
&
other
)
~
Tensor
()
=
default
;
Tensor
&
operator
=
(
const
Tensor
&
)
=
default
;
Tensor
&
operator
=
(
Tensor
&&
)
=
default
;
template
<
typename
FromT
>
explicit
Tensor
(
const
Tensor
<
FromT
>&
other
)
:
Tensor
(
other
.
template
CopyAsType
<
T
>())
{
mDesc
=
other
.
mDesc
;
mData
=
other
.
mData
;
return
*
this
;
}
const
std
::
vector
<
std
::
size_t
>&
GetLengths
()
const
{
return
mDesc
.
GetLengths
();
}
...
...
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
b89a88b5
...
...
@@ -5,6 +5,7 @@
#include <cmath>
#include <numeric>
#include <random>
#include "ck/ck.hpp"
...
...
@@ -126,6 +127,23 @@ struct GeneratorTensor_3<ck::bhalf_t>
}
};
template
<
typename
T
>
struct
GeneratorTensor_4
{
std
::
default_random_engine
generator
;
std
::
normal_distribution
<
float
>
distribution
;
GeneratorTensor_4
(
float
mean
,
float
stddev
)
:
generator
(
1
),
distribution
(
mean
,
stddev
){};
template
<
typename
...
Is
>
T
operator
()(
Is
...)
{
float
tmp
=
distribution
(
generator
);
return
ck
::
type_convert
<
T
>
(
tmp
);
}
};
struct
GeneratorTensor_Checkboard
{
template
<
typename
...
Ts
>
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
b89a88b5
...
...
@@ -16,6 +16,7 @@ add_subdirectory(batched_gemm)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
batched_gemm_gemm
)
add_subdirectory
(
batched_gemm_softmax_gemm
)
add_subdirectory
(
batched_gemm_add_relu_gemm_add
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
contraction_scale
)
add_subdirectory
(
contraction_bilinear
)
...
...
@@ -42,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_add_relu_gemm_add_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_contraction_scale_instance>
...
...
@@ -78,6 +80,7 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/problem_transform>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/device>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/device/impl>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/grid>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/block>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/warp>
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt
0 → 100644
View file @
b89a88b5
add_instance_library
(
device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultipleDGemmMultipleD
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
F16
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
=
std
::
tuple
<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultipleDGemmMultipleD
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
F16
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
View file @
b89a88b5
add_instance_library
(
device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
b89a88b5
...
...
@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
...
...
@@ -34,10 +35,22 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can trigger compiler crash in mainline #9110 but not in #10738
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
b89a88b5
...
...
@@ -26,6 +26,8 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNOPadding
;
// Padding K is currently flawed
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
...
...
@@ -35,10 +37,21 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
8
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
8
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp
View file @
b89a88b5
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_
5ary_
elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -27,19 +27,17 @@ using outputType = F16;
using
Normalize
=
ck
::
tensor_operation
::
element_wise
::
Normalize
;
using
device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances
=
std
::
tuple
<
// clang-format off
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
Device5AryElementwise
<
F16
,
F32
,
F32
,
F16
,
F16
,
F16
,
F32
,
Normalize
,
2
,
8
,
8
,
1
,
1
,
8
,
8
,
8
>
,
Device5AryElementwise
<
F16
,
F32
,
F32
,
F16
,
F16
,
F16
,
F32
,
Normalize
,
2
,
4
,
4
,
1
,
1
,
4
,
4
,
4
>
,
Device5AryElementwise
<
F16
,
F32
,
F32
,
F16
,
F16
,
F16
,
F32
,
Normalize
,
2
,
2
,
2
,
1
,
1
,
2
,
2
,
2
>
,
Device5AryElementwise
<
F16
,
F32
,
F32
,
F16
,
F16
,
F16
,
F32
,
Normalize
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
//###################|<in, mean, square_mean, gamma, beta>| <out>| functor| NDim| MPerThread| <in, mean, square_mean, gamma, beta ScalarPerVector>| <out ScalarPerVector>|
DeviceElementwise
<
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
Tuple
<
F16
>
,
Normalize
,
2
,
8
,
Sequence
<
8
,
1
,
1
,
8
,
8
>
,
Sequence
<
8
>
>
,
DeviceElementwise
<
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
Tuple
<
F16
>
,
Normalize
,
2
,
4
,
Sequence
<
4
,
1
,
1
,
4
,
4
>
,
Sequence
<
4
>
>
,
DeviceElementwise
<
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
Tuple
<
F16
>
,
Normalize
,
2
,
2
,
Sequence
<
2
,
1
,
1
,
2
,
2
>
,
Sequence
<
2
>
>
,
DeviceElementwise
<
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
Tuple
<
F16
>
,
Normalize
,
2
,
1
,
Sequence
<
1
,
1
,
1
,
1
,
1
>
,
Sequence
<
1
>
>
// clang-format on
>
;
void
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances
(
std
::
vector
<
DeviceElementwisePtr
<
5
,
1
,
2
,
Normalize
>>&
instances
)
std
::
vector
<
DeviceElementwiseBasePtr
<
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
Tuple
<
F16
>
,
Normalize
,
2
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances
{});
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
View file @
b89a88b5
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk
_c_shuffle
.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -31,18 +31,18 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#################|
AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CThreadTransfer| CThread
Transfer|
//#################|
Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcDstVectorDim| DstScala
r|
//#################|
| | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
| PerVector
|
//#################|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
//#################
########
|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlock
Transfer|
//#################
########
| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVecto
r|
//#################
########
| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl
|
//#################
########
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
| |
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
View file @
b89a88b5
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk
_c_shuffle
.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -31,18 +31,18 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#################|
AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CThreadTransfer| CThread
Transfer|
//#################|
Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcDstVectorDim| DstScala
r|
//#################|
| | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
| PerVector
|
//#################|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
//#################
########
|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlock
Transfer|
//#################
########
| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVecto
r|
//#################
########
| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl
|
//#################
########
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
| |
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
View file @
b89a88b5
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk
_c_shuffle
.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -26,28 +26,23 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
Gemm
MNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
Gemm
Default
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//###################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM|Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
96
,
128
,
4
,
8
,
16
,
16
,
3
,
4
,
S
<
1
,
4
,
32
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
32
,
256
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
16
,
256
,
4
,
4
,
16
,
16
,
1
,
4
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
16
,
128
,
4
,
4
,
16
,
16
,
1
,
4
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
7
,
1
>
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitKCShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
View file @
b89a88b5
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk
_c_shuffle
.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -31,23 +31,23 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#################|
AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CThreadTransfer| CThread
Transfer|
//#################|
Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcDstVectorDim| DstScala
r|
//#################|
| | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
| PerVector
|
//#################|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
,
DeviceGemmXdlSplitK
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
7
,
1
>
//#################
########
|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlock
Transfer|
//#################
########
| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVecto
r|
//#################
########
| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl
|
//#################
########
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
| |
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGemmXdlSplitK
CShuffle
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
4
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/utility/data_type.hpp"
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
namespace
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Pass
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
template
<
index_t
Rank
,
index_t
Reduce
>
using
device_softmax_f16_f16_instances
=
std
::
tuple
<
// clang-format off
//
InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
>
//
InDataType, AccDataType, OutDataType,
InElementwiseOp, AccElementwiseOp,
Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
>
// clang-format on
>
;
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
3
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
2
>
{});
}
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
4
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
2
>
{});
...
...
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/
utility/data_type
.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_softmax
_impl
.hpp"
#include "ck/
tensor_operation/gpu/element/unary_element_wise_operation
.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F32
=
float
;
namespace
{
using
F32
=
float
;
using
Pass
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
template
<
index_t
Rank
,
index_t
Reduce
>
using
device_softmax_f32_f32_instances
=
std
::
tuple
<
// clang-format off
//
InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
4
>
//
InDataType, AccDataType, OutDataType,
InElementwiseOp, AccElementwiseOp,
Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
4
>
// clang-format on
>
;
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
3
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
2
>
{});
}
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
4
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
2
>
{});
...
...
profiler/CMakeLists.txt
View file @
b89a88b5
...
...
@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm_gemm.cpp
src/profile_batched_gemm_add_relu_gemm_add.cpp
src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_fwd.cpp
...
...
@@ -35,6 +37,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_add_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_grouped_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_instance
)
...
...
profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace
ck
{
namespace
profiler
{
template
<
typename
A0Layout
,
typename
B0Layout
,
typename
D0sLayout
,
typename
B1Layout
,
typename
D1sLayout
,
typename
E1Layout
,
typename
A0DataType
,
typename
B0DataType
,
typename
D0sDataType
,
typename
B1DataType
,
typename
D1sDataType
,
typename
E1DataType
>
bool
profile_batched_gemm_add_relu_gemm_add_impl
(
bool
do_verification
,
int
init_method
,
bool
do_log
,
bool
time_kernel
,
int
M
,
int
N
,
int
K
,
int
O
,
int
BatchCount
=
1
,
int
StrideA0
=
-
1
,
int
StrideB0
=
-
1
,
int
StrideD0
=
-
1
,
int
StrideB1
=
-
1
,
int
StrideD1
=
-
1
,
int
StrideE1
=
-
1
,
int
BatchStrideA0
=
-
1
,
int
BatchStrideB0
=
-
1
,
int
BatchStrideD0
=
-
1
,
int
BatchStrideB1
=
-
1
,
int
BatchStrideD1
=
-
1
,
int
BatchStrideE1
=
-
1
)
{
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
A0ElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
B1ElementOp
=
PassThrough
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
0
,
D0sDataType
>>
;
using
D0Layout
=
remove_cvref_t
<
tuple_element_t
<
0
,
D0sLayout
>>
;
using
D1DataType
=
remove_cvref_t
<
tuple_element_t
<
0
,
D1sDataType
>>
;
using
D1Layout
=
remove_cvref_t
<
tuple_element_t
<
0
,
D1sLayout
>>
;
// for reference
using
RefAcc0DataType
=
float
;
using
RefAcc1DataType
=
float
;
bool
pass
=
true
;
const
int
DefaultStrideA0
=
ck
::
is_same_v
<
A0Layout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideD0
=
ck
::
is_same_v
<
D0Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideB1
=
ck
::
is_same_v
<
B1Layout
,
Row
>
?
O
:
N
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D1Layout
,
Row
>
?
O
:
M
;
const
int
DefaultStrideE1
=
ck
::
is_same_v
<
E1Layout
,
Row
>
?
O
:
M
;
StrideA0
=
(
StrideA0
<
0
)
?
DefaultStrideA0
:
StrideA0
;
StrideB0
=
(
StrideB0
<
0
)
?
DefaultStrideB0
:
StrideB0
;
StrideD0
=
(
StrideD0
<
0
)
?
DefaultStrideD0
:
StrideD0
;
StrideB1
=
(
StrideB1
<
0
)
?
DefaultStrideB1
:
StrideB1
;
StrideD1
=
(
StrideD1
<
0
)
?
DefaultStrideD1
:
StrideD1
;
StrideE1
=
(
StrideE1
<
0
)
?
DefaultStrideE1
:
StrideE1
;
const
int
DefaultBatchStrideA0
=
(
ck
::
is_same_v
<
A0Layout
,
Col
>
?
K
:
M
)
*
StrideA0
;
const
int
DefaultBatchStrideB0
=
(
ck
::
is_same_v
<
B0Layout
,
Col
>
?
N
:
K
)
*
StrideB0
;
const
int
DefaultBatchStrideD0
=
(
ck
::
is_same_v
<
D0Layout
,
Col
>
?
N
:
M
)
*
StrideD0
;
const
int
DefaultBatchStrideB1
=
(
ck
::
is_same_v
<
B1Layout
,
Col
>
?
O
:
N
)
*
StrideB1
;
const
int
DefaultBatchStrideD1
=
(
ck
::
is_same_v
<
D1Layout
,
Col
>
?
O
:
M
)
*
StrideD1
;
const
int
DefaultBatchStrideE1
=
(
ck
::
is_same_v
<
E1Layout
,
Col
>
?
O
:
M
)
*
StrideE1
;
BatchStrideA0
=
BatchStrideA0
<
0
?
DefaultBatchStrideA0
:
BatchStrideA0
;
BatchStrideB0
=
BatchStrideB0
<
0
?
DefaultBatchStrideB0
:
BatchStrideB0
;
BatchStrideD0
=
BatchStrideD0
<
0
?
DefaultBatchStrideD0
:
BatchStrideD0
;
BatchStrideB1
=
BatchStrideB1
<
0
?
DefaultBatchStrideB1
:
BatchStrideB1
;
BatchStrideD1
=
BatchStrideD1
<
0
?
DefaultBatchStrideD1
:
BatchStrideD1
;
BatchStrideE1
=
BatchStrideE1
<
0
?
DefaultBatchStrideE1
:
BatchStrideE1
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
Row
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
batch_stride
,
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
batch_stride
,
1
,
stride
}));
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor
<
A0DataType
>
a0_g_m_k
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA0
,
BatchStrideA0
,
A0Layout
{}));
Tensor
<
B0DataType
>
b0_g_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB0
,
BatchStrideB0
,
B0Layout
{}));
Tensor
<
D0DataType
>
d0_g_m_n
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideD0
,
BatchStrideD0
,
D0Layout
{}));
Tensor
<
B1DataType
>
b1_g_n_o
(
f_host_tensor_descriptor
(
BatchCount
,
N
,
O
,
StrideB1
,
BatchStrideB1
,
B1Layout
{}));
Tensor
<
D1DataType
>
d1_g_m_o
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
StrideD1
,
BatchStrideD1
,
D1Layout
{}));
Tensor
<
E1DataType
>
e1_g_m_o_host_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
StrideE1
,
BatchStrideE1
,
E1Layout
{}));
Tensor
<
E1DataType
>
e1_g_m_o_device_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
StrideE1
,
BatchStrideE1
,
E1Layout
{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor
<
RefAcc0DataType
>
c0_g_m_n
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
N
,
M
*
N
,
Row
{}));
Tensor
<
RefAcc0DataType
>
e0_g_m_n
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
N
,
M
*
N
,
Row
{}));
Tensor
<
RefAcc1DataType
>
c1_g_m_o
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
O
,
M
*
O
,
Row
{}));
std
::
cout
<<
"a0_g_m_k: "
<<
a0_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_g_k_n: "
<<
b0_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_g_m_n: "
<<
d0_g_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_g_n_o: "
<<
b1_g_n_o
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_g_m_o: "
<<
d1_g_m_o
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e1_g_m_o: "
<<
e1_g_m_o_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a0_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
3
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
3
});
d0_g_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
3
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
3
});
d1_g_m_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
3
});
break
;
default:
a0_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
d0_g_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
-
0.5
,
0.5
});
d1_g_m_o
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
DeviceMem
a0_g_m_k_device_buf
(
sizeof
(
A0DataType
)
*
a0_g_m_k
.
mDesc
.
GetElementSize
());
DeviceMem
b0_g_k_n_device_buf
(
sizeof
(
B0DataType
)
*
b0_g_k_n
.
mDesc
.
GetElementSize
());
DeviceMem
d0_g_m_n_device_buf
(
sizeof
(
D0DataType
)
*
d0_g_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b1_g_n_o_device_buf
(
sizeof
(
B1DataType
)
*
b1_g_n_o
.
mDesc
.
GetElementSize
());
DeviceMem
d1_g_m_o_device_buf
(
sizeof
(
D1DataType
)
*
d1_g_m_o
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e1_g_m_o_device_buf
(
sizeof
(
E1DataType
)
*
e1_g_m_o_device_result
.
mDesc
.
GetElementSize
());
a0_g_m_k_device_buf
.
ToDevice
(
a0_g_m_k
.
mData
.
data
());
b0_g_k_n_device_buf
.
ToDevice
(
b0_g_k_n
.
mData
.
data
());
d0_g_m_n_device_buf
.
ToDevice
(
d0_g_m_n
.
mData
.
data
());
b1_g_n_o_device_buf
.
ToDevice
(
b1_g_n_o
.
mData
.
data
());
d1_g_m_o_device_buf
.
ToDevice
(
d1_g_m_o
.
mData
.
data
());
auto
a0_element_op
=
A0ElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
cde0_element_op
=
CDE0ElementOp
{};
auto
b1_element_op
=
B1ElementOp
{};
auto
cde1_element_op
=
CDE1ElementOp
{};
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmMultipleDGemmMultipleD
<
A0Layout
,
B0Layout
,
D0sLayout
,
B1Layout
,
D1sLayout
,
E1Layout
,
A0DataType
,
B0DataType
,
D0sDataType
,
B1DataType
,
D1sDataType
,
E1DataType
,
A0ElementOp
,
B0ElementOp
,
CDE0ElementOp
,
B1ElementOp
,
CDE1ElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
if
(
do_verification
)
{
// Ref Gemm0
using
ReferenceGemm0Instance
=
tensor_operation
::
host
::
ReferenceBatchedGemm
<
A0DataType
,
B0DataType
,
RefAcc0DataType
,
RefAcc0DataType
,
A0ElementOp
,
B0ElementOp
,
PassThrough
>
;
// Ref Gemm1
using
ReferenceGemm1Instance
=
tensor_operation
::
host
::
ReferenceBatchedGemm
<
RefAcc0DataType
,
B1DataType
,
RefAcc1DataType
,
RefAcc1DataType
,
PassThrough
,
B1ElementOp
,
PassThrough
>
;
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a0_g_m_k
,
b0_g_k_n
,
c0_g_m_n
,
a0_element_op
,
b0_element_op
,
PassThrough
{});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// cde0_elementwise
e0_g_m_n
.
ForEach
(
[
&
](
auto
&
,
auto
idx
)
{
cde0_element_op
(
e0_g_m_n
(
idx
),
c0_g_m_n
(
idx
),
d0_g_m_n
(
idx
));
});
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
e0_g_m_n
,
b1_g_n_o
,
c1_g_m_o
,
PassThrough
{},
b1_element_op
,
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
// cde1_elementwise
e1_g_m_o_host_result
.
ForEach
([
&
](
auto
&
,
auto
idx
)
{
cde1_element_op
(
e1_g_m_o_host_result
(
idx
),
c1_g_m_o
(
idx
),
d1_g_m_o
(
idx
));
});
}
std
::
string
best_op_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device op instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
A0DataType
*>
(
a0_g_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_g_k_n_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
1
>
{
d0_g_m_n_device_buf
.
GetDeviceBuffer
()},
static_cast
<
B1DataType
*>
(
b1_g_n_o_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
1
>
{
d1_g_m_o_device_buf
.
GetDeviceBuffer
()},
static_cast
<
E1DataType
*>
(
e1_g_m_o_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
O
,
BatchCount
,
StrideA0
,
StrideB0
,
std
::
array
<
ck
::
index_t
,
1
>
{
StrideD0
},
StrideB1
,
std
::
array
<
ck
::
index_t
,
1
>
{
StrideD1
},
StrideE1
,
BatchStrideA0
,
BatchStrideB0
,
std
::
array
<
ck
::
index_t
,
1
>
{
BatchStrideD0
},
BatchStrideB1
,
std
::
array
<
ck
::
index_t
,
1
>
{
BatchStrideD1
},
BatchStrideE1
,
a0_element_op
,
b0_element_op
,
cde0_element_op
,
b1_element_op
,
cde1_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
A0DataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
D0DataType
)
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
E1DataType
)
*
M
*
O
+
sizeof
(
D1DataType
)
*
O
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
e1_g_m_o_device_buf
.
FromDevice
(
e1_g_m_o_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
e1_g_m_o_device_result
.
mData
,
e1_g_m_o_host_result
.
mData
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"e1_g_m_o_host_result : "
,
e1_g_m_o_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"e1_g_m_o_device_result : "
,
e1_g_m_o_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
}
}
// namespace profiler
}
// namespace ck
profiler/include/profile_batched_gemm_gemm_impl.hpp
View file @
b89a88b5
...
...
@@ -195,6 +195,12 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
// early fail when no instances are found
if
(
op_ptrs
.
size
()
==
0
)
{
return
false
;
}
if
(
do_verification
)
{
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
...
...
Prev
1
…
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment