Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
823c8801
Commit
823c8801
authored
Jun 13, 2023
by
aska-0096
Browse files
Merge branch 'e2e_kernellib' of
https://github.com/aska-0096/navi3x_ck
into e2e_kernellib
parents
e305e41e
efee4541
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
2761 additions
and
434 deletions
+2761
-434
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+2
-0
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+194
-71
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
..._scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
+332
-0
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
...tmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
+169
-127
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention.inc
...2_batched_gemm_scale_softmax_gemm/run_cross_attention.inc
+344
-0
example/32_batched_gemm_scale_softmax_gemm/run_self_attention.inc
...32_batched_gemm_scale_softmax_gemm/run_self_attention.inc
+343
-0
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
...m_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
+288
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+976
-96
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+4
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+11
-106
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+44
-31
script/unet_mha.sh
script/unet_mha.sh
+52
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
823c8801
...
@@ -8,6 +8,8 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
...
@@ -8,6 +8,8 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp
)
add_example_executable
(
example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp
)
endif
()
endif
()
add_custom_target
(
example_gemm_scale_softmax_gemm
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
823c8801
...
@@ -67,77 +67,200 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
...
@@ -67,77 +67,200 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
// clang-format off
// #define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
NumDimM
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
NumDimN
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
NumDimK
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
NumDimO
,
32
,
ADataType
,
// Gemm 0
B0DataType
,
16
,
128
,
64
,
8
,
8
,
B1DataType
,
// Gemm 1
CDataType
,
64
,
64
,
8
,
Acc0BiasDataType
,
16
,
16
,
16
,
Acc0DataType
,
// Per repeat = wave_m = wave_num, wave_n = 1
Acc1BiasDataType
,
1
,
8
,
4
,
Acc1DataType
,
// ABlockTransfer MK -> K0 M K1
CShuffleDataType
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
AElementOp
,
// B0BlockTransfer LK -> K0 L K1
B0ElementOp
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
Acc0ElementOp
,
// B1BlockTransfer NL -> L0 N L1
B1ElementOp
,
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
CElementOp
,
// CShuffleBlockTransfer MN
GemmSpec
,
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
TensorSpecA
,
MaskingSpec
>
,
TensorSpecB0
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
TensorSpecB1
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
TensorSpecC
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
1
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
// Gemm 0
16
,
64
,
64
,
8
,
8
,
// Gemm 1
64
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
4
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
// Gemm 0
32
,
128
,
64
,
8
,
8
,
// Gemm 1
64
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
4
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
32
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
// Gemm 0
32
,
64
,
64
,
8
,
8
,
// Gemm 1
64
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
4
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
32
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
// Gemm 0
64
,
128
,
64
,
8
,
8
,
// Gemm 1
64
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
4
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
// Gemm 0
64
,
64
,
64
,
8
,
8
,
// Gemm 1
64
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
4
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
256
,
// Gemm 0
// Gemm 0
128
,
// MPerBlock
128
,
128
,
64
,
8
,
8
,
64
,
// LPerBlock
64
,
// KPerBlock
8
,
// K1
// Gemm 1
// Gemm 1
64
,
// NPerBlock
64
,
64
,
8
,
64
,
// LTilePerBlock
16
,
16
,
16
,
8
,
// L1
16
,
// MPerWMMA
16
,
// LPerWMMA
16
,
// NPerWMMA
// Per repeat = wave_m = wave_num, wave_n = 1
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
// MRepeat
1
,
8
,
4
,
4
,
// LRepeat
// ABlockTransfer MK -> K0 M K1
4
,
// NRepeat
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// ABlockTransfer MK -> K0 M K1
// B0BlockTransfer LK -> K0 L K1
S
<
1
,
0
,
2
>
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
1
,
0
,
2
>
,
// B1BlockTransfer NL -> L0 N L1
2
,
S
<
2
,
16
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
false
,
8
,
// CShuffleBlockTransfer MN
8
,
1
,
1
,
S
<
1
,
128
,
1
,
2
>
,
8
,
true
,
MaskingSpec
>
,
S
<
4
,
64
,
1
>
,
// B0BlockTransfer LK -> K0 L K1
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
S
<
1
,
0
,
2
>
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
S
<
1
,
0
,
2
>
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
2
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
8
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
8
,
256
,
true
,
// Gemm 0
S
<
4
,
8
,
8
>
,
// B1BlockTransfer NL -> L0 N L1
128
,
128
,
64
,
8
,
8
,
S
<
0
,
2
,
1
>
,
// Gemm 1
S
<
0
,
2
,
1
>
,
64
,
64
,
8
,
1
,
16
,
16
,
16
,
8
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
1
,
8
,
4
,
false
,
// ABlockTransfer MK -> K0 M K1
1
,
// CShuffleMWmmaPerWavePerShuffle
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
2
,
// CShuffleNWmmaPerWavePerShuffle
// B0BlockTransfer LK -> K0 L K1
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
// B1BlockTransfer NL -> L0 N L1
MaskingSpec
>
;
// MaskingSpecialization
S
<
2
,
16
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
128
,
1
,
2
>
,
8
,
MaskingSpec
>
#endif
>
;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
B0DataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
0 → 100644
View file @
823c8801
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
Acc0DataType
=
F32
;
using
Acc1DataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F16
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// clang-format off
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
// Gemm 0
16
,
32
,
160
,
8
,
8
,
// Gemm 1
80
,
32
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
2
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
// Gemm 0
16
,
64
,
80
,
8
,
8
,
// Gemm 1
80
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
// Gemm 0
16
,
64
,
48
,
8
,
8
,
// Gemm 1
48
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
3
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
// Gemm 0
32
,
64
,
48
,
8
,
8
,
// Gemm 1
48
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
3
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
32
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
// Gemm 0
32
,
64
,
80
,
8
,
8
,
// Gemm 1
80
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
32
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
// Gemm 0
32
,
32
,
160
,
8
,
8
,
// Gemm 1
80
,
32
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
2
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
32
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
// Gemm 0
64
,
128
,
80
,
8
,
8
,
// Gemm 1
80
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
// Gemm 0
64
,
192
,
48
,
8
,
8
,
// Gemm 1
48
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
12
,
3
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
// Gemm 0
64
,
64
,
48
,
8
,
8
,
// Gemm 1
48
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
3
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
// Gemm 0
128
,
192
,
48
,
8
,
4
,
// Gemm 1
48
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
12
,
3
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
16
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
128
,
1
,
2
>
,
8
,
MaskingSpec
>
#endif
>
;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
Acc0DataType
,
Acc1DataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
Acc0DataType
,
ADataType
,
Acc0DataType
>
;
// Ref Gemm1: fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
Acc1DataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
#include "run_cross_attention.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
View file @
823c8801
...
@@ -127,16 +127,31 @@ int run(int argc, char* argv[])
...
@@ -127,16 +127,31 @@ int run(int argc, char* argv[])
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
break
;
case
6
:
// Rand: a b0 ; unit:
b1 pass
case
6
:
// Rand: a b0 ; unit:
B1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
break
;
case
7
:
// Rand: a b1 ; unit: b0
pass
case
7
:
// Rand: a b1 ; unit: b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
break
;
case
8
:
// Rand: a ; unit: b0 b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
9
:
// Rand: b0 ; unit: a b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
10
:
// Rand: b1 ; unit: a b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
...
@@ -160,39 +175,37 @@ int run(int argc, char* argv[])
...
@@ -160,39 +175,37 @@ int run(int argc, char* argv[])
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
// do GEMM
float
best_perf
=
.0
;
float
best_time
=
.0
;
int
not_pass
=
0
;
std
::
string
best_kernel
=
""
;
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
// TODO ANT: replace array with vector?
auto
gemm
=
DeviceGemmInstance
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_conv_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_conv_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
M
,
{},
// std::array<void*, 1> p_acc1_biases;
N
,
a_gs_ms_ks_lengths
,
K
,
a_gs_ms_ks_strides
,
O
,
b0_gs_ns_ks_lengths
,
G0
,
b0_gs_ns_ks_strides
,
G1
,
b1_gs_os_ns_lengths
,
alpha
,
b1_gs_os_ns_strides
,
input_permute
,
c_gs_ms_os_lengths
,
output_permute
);
c_gs_ms_os_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
//
return 0;
}
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
ck
::
index_t
BatchCount
=
G0
*
G1
;
...
@@ -208,9 +221,14 @@ int run(int argc, char* argv[])
...
@@ -208,9 +221,14 @@ int run(int argc, char* argv[])
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
tflops
>
best_perf
)
{
best_perf
=
tflops
;
best_time
=
ave_time
*
1000
;
best_kernel
=
gemm
.
GetTypeString
();
}
if
(
do_verification
)
if
(
do_verification
)
{
{
c_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
...
@@ -242,7 +260,7 @@ int run(int argc, char* argv[])
...
@@ -242,7 +260,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
// masking
const
auto
mask
=
Device
Gemm
Instance
::
C0MatrixMask
(
N
);
const
auto
mask
=
typename
Device
MHA
Instance
::
C0MatrixMask
(
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
@@ -258,8 +276,12 @@ int run(int argc, char* argv[])
...
@@ -258,8 +276,12 @@ int run(int argc, char* argv[])
// gemm1
// gemm1
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
a1_g_m_n
,
a1_g_m_n
,
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
...
@@ -285,14 +307,34 @@ int run(int argc, char* argv[])
...
@@ -285,14 +307,34 @@ int run(int argc, char* argv[])
atol
=
1
e
-
2
;
atol
=
1
e
-
2
;
}
}
return
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
bool
this_run_verification
=
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol
,
rtol
,
atol
)
atol
)
;
?
0
printf
(
"Verification: %s, Pass: %s
\n
"
,
:
1
;
do_verification
?
"ON"
:
"OFF"
,
}
this_run_verification
?
"YES"
:
"NO"
);
return
0
;
if
(
!
this_run_verification
)
{
not_pass
=
1
;
printf
(
"%d th MHA instance verification Failed
\n
"
,
i
.
value
);
}
}
});
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Problem Size: BatchCount: "
<<
G0
<<
", HeadNum: "
<<
G1
<<
", M: "
<<
M
<<
", N: "
<<
N
<<
", K: "
<<
K
<<
", O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Best kernel: "
<<
best_kernel
<<
" , "
<<
best_perf
<<
" TFlops , "
<<
best_time
<<
" us"
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
return
not_pass
;
}
}
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention.inc
0 → 100644
View file @
823c8801
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck
::
index_t
M
=
256
;
ck
::
index_t
N
=
64
;
ck
::
index_t
K
=
80
;
ck
::
index_t
O
=
80
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
2
;
ck
::
index_t
G1
=
8
;
float
alpha
=
1
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
Tensor
<
B1DataType
>
b1_gs_os_ns
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_device_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
std
::
cout
<<
"a_gs_ms_ks: "
<<
a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_gs_ns_ks: "
<<
b0_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_gs_os_ns: "
<<
b1_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_gs_ms_os: "
<<
c_gs_ms_os_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
0.0
,
1.0
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
break
;
case
4
:
// A, B0, B1 1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
5
:
// Rand: b1 b0; unit: a
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
6
:
// Rand: a b0 ; unit: B1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
7
:
// Rand: a b1 ; unit: b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
8
:
// Rand: a ; unit: b0 b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
9
:
// Rand: b0 ; unit: a b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
10
:
// Rand: b1 ; unit: a b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
std
::
vector
<
ck
::
index_t
>
kv_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
2
,
K
};
std
::
vector
<
ck
::
index_t
>
kv_gs_ns_ks_strides
=
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
2
*
K
,
2
*
K
,
G1
*
2
*
K
,
K
,
1
};
// kv layout [G0, M, G1, 2, K]
Tensor
<
ADataType
>
kv_gs_ns_ks
(
kv_gs_ns_ks_lengths
,
kv_gs_ns_ks_strides
);
// merge kv into a packed pointer send to device
b0_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
kv_gs_ns_ks
(
idx
[
0
],
idx
[
1
],
idx
[
2
],
0
,
idx
[
3
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
kv_gs_ns_ks
(
idx
[
0
],
idx
[
1
],
idx
[
3
],
1
,
idx
[
2
])
=
self
(
idx
);
});
DeviceMem
q_device_buf
(
sizeof
(
ADataType
)
*
a_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kv_device_buf
(
sizeof
(
B0DataType
)
*
b0_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
()
+
sizeof
(
B1DataType
)
*
b1_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_gs_ms_os_device_result
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
a_gs_ms_ks
.
mData
.
data
());
kv_device_buf
.
ToDevice
(
kv_gs_ns_ks
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{
alpha
};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
float
best_perf
=
.0
;
float
best_time
=
.0
;
int
not_pass
=
0
;
std
::
string
best_kernel
=
""
;
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_conv_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_conv_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeCrossAttnInvoker
();
auto
argument
=
gemm
.
MakeCrossAttnArgument
(
static_cast
<
ADataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
kv_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
G0
,
M
,
N
,
G1
,
K
,
alpha
);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck
::
index_t
BatchCount
=
G0
*
G1
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
tflops
>
best_perf
)
{
best_perf
=
tflops
;
best_time
=
ave_time
*
1000
;
best_kernel
=
gemm
.
GetTypeString
();
}
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Acc0DataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// gemm 0
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
const
auto
mask
=
typename
DeviceMHAInstance
::
C0MatrixMask
(
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// softmax
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
acc0_g_m_n
,
a1_g_m_n
,
1
,
0
,
{
2
});
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// gemm1
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
a1_g_m_n
,
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
// default absolute error and relative error is 0.001
double
rtol
=
1
e
-
3
;
double
atol
=
1
e
-
3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B0DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B1DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1
e
-
2
;
atol
=
1
e
-
2
;
}
bool
this_run_verification
=
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
"Error: Incorrect results!"
,
rtol
,
atol
);
printf
(
"Verification: %s, Pass: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
,
this_run_verification
?
"YES"
:
"NO"
);
if
(
!
this_run_verification
)
{
not_pass
=
1
;
printf
(
"%d th MHA instance verification Failed
\n
"
,
i
.
value
);
}
}
});
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Problem Size: BatchCount: "
<<
G0
<<
", HeadNum: "
<<
G1
<<
", M: "
<<
M
<<
", N: "
<<
N
<<
", K: "
<<
K
<<
", O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Best kernel: "
<<
best_kernel
<<
" , "
<<
best_perf
<<
" TFlops , "
<<
best_time
<<
" us"
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
return
not_pass
;
}
example/32_batched_gemm_scale_softmax_gemm/run_self_attention.inc
0 → 100644
View file @
823c8801
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck
::
index_t
M
=
256
;
ck
::
index_t
N
=
256
;
ck
::
index_t
K
=
80
;
ck
::
index_t
O
=
80
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
2
;
ck
::
index_t
G1
=
8
;
float
alpha
=
1
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
Tensor
<
B1DataType
>
b1_gs_os_ns
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_device_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
std
::
cout
<<
"a_gs_ms_ks: "
<<
a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_gs_ns_ks: "
<<
b0_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_gs_os_ns: "
<<
b1_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_gs_ms_os: "
<<
c_gs_ms_os_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
0.0
,
1.0
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
break
;
case
4
:
// A, B0, B1 1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
5
:
// Rand: b1 b0; unit: a
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
6
:
// Rand: a b0 ; unit: B1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
7
:
// Rand: a b1 ; unit: b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
8
:
// Rand: a ; unit: b0 b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
9
:
// Rand: b0 ; unit: a b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
10
:
// Rand: b1 ; unit: a b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
std
::
vector
<
ck
::
index_t
>
qkv_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
3
,
K
};
std
::
vector
<
ck
::
index_t
>
qkv_gs_ms_ks_strides
=
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
3
*
K
,
3
*
K
,
G1
*
3
*
K
,
K
,
1
};
// qkv layout [G0, M, G1, 3, K]
Tensor
<
ADataType
>
qkv_gs_ms_ks
(
qkv_gs_ms_ks_lengths
,
qkv_gs_ms_ks_strides
);
// merge qkv into a packed pointer send to device
a_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
qkv_gs_ms_ks
(
idx
[
0
],
idx
[
1
],
idx
[
2
],
0
,
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
qkv_gs_ms_ks
(
idx
[
0
],
idx
[
1
],
idx
[
2
],
1
,
idx
[
3
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
qkv_gs_ms_ks
(
idx
[
0
],
idx
[
1
],
idx
[
3
],
2
,
idx
[
2
])
=
self
(
idx
);
});
DeviceMem
qkv_device_buf
(
sizeof
(
ADataType
)
*
a_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
()
+
sizeof
(
B0DataType
)
*
b0_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
()
+
sizeof
(
B1DataType
)
*
b1_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_gs_ms_os_device_result
.
mDesc
.
GetElementSpaceSize
());
qkv_device_buf
.
ToDevice
(
qkv_gs_ms_ks
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{
alpha
};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
float
best_perf
=
.0
;
float
best_time
=
.0
;
int
not_pass
=
0
;
std
::
string
best_kernel
=
""
;
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_conv_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_conv_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeSelfAttnInvoker
();
auto
argument
=
gemm
.
MakeSelfAttnArgument
(
static_cast
<
ADataType
*>
(
qkv_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
G0
,
M
,
G1
,
K
,
alpha
);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck
::
index_t
BatchCount
=
G0
*
G1
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
tflops
>
best_perf
)
{
best_perf
=
tflops
;
best_time
=
ave_time
*
1000
;
best_kernel
=
gemm
.
GetTypeString
();
}
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Acc0DataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// gemm 0
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
const
auto
mask
=
typename
DeviceMHAInstance
::
C0MatrixMask
(
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// softmax
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
acc0_g_m_n
,
a1_g_m_n
,
1
,
0
,
{
2
});
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// gemm1
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
a1_g_m_n
,
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
// default absolute error and relative error is 0.001
double
rtol
=
1
e
-
3
;
double
atol
=
1
e
-
3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B0DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B1DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1
e
-
2
;
atol
=
1
e
-
2
;
}
bool
this_run_verification
=
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
"Error: Incorrect results!"
,
rtol
,
atol
);
printf
(
"Verification: %s, Pass: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
,
this_run_verification
?
"YES"
:
"NO"
);
if
(
!
this_run_verification
)
{
not_pass
=
1
;
printf
(
"%d th MHA instance verification Failed
\n
"
,
i
.
value
);
}
}
});
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Problem Size: BatchCount: "
<<
G0
<<
", HeadNum: "
<<
G1
<<
", M: "
<<
M
<<
", N: "
<<
N
<<
", K: "
<<
K
<<
", O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Best kernel: "
<<
best_kernel
<<
" , "
<<
best_perf
<<
" TFlops , "
<<
best_time
<<
" us"
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
return
not_pass
;
}
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
0 → 100644
View file @
823c8801
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
// Shorthand for the compile-time integer sequences used by the block-transfer
// tile descriptions in DeviceMHAFactory below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Problem data types: fp16 inputs/outputs, fp32 accumulation for both GEMMs.
using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using Acc0DataType     = F32;
using Acc1DataType     = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
// Empty tuples: no bias tensors are fused into either GEMM.
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

// Dimension split: G = (batch, head) -> two batch dims; M/N/K/O are 1-D each.
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

// Elementwise ops: only the Gemm0 accumulator is scaled (Scale carries the
// softmax scaling factor); all other operands pass through unchanged.
using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

// Pad M, N, K and O so arbitrary problem sizes are accepted.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// Attention mask disabled (no causal/triangular masking).
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
// Each CK_MHA_USE_WAVE_<n> flag enables a pair of kernel instances tuned for a
// block of <n> waves (32 threads per wave); the pair differs in the Gemm0
// L-per-block tile (128 vs 64) and the matching L-tile repeat count.
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
// Tuple of candidate DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// instantiations; the runner iterates over it to pick the fastest instance.
using DeviceMHAFactory =
    std::tuple<
#ifdef CK_MHA_USE_WAVE_1
        // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 32,
            // Gemm 0
            16, 128, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 16, 1, 2>, 8, MaskingSpec>,
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 32,
            // Gemm 0
            16, 64, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 16, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 64,
            // Gemm 0
            32, 128, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 32, 1, 2>, 8, MaskingSpec>,
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 64,
            // Gemm 0
            32, 64, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 32, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 128,
            // Gemm 0
            64, 128, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 64, 1, 2>, 8, MaskingSpec>,
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 128,
            // Gemm 0
            64, 64, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 64, 1, 2>, 8, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
        // NOTE(review): the two WAVE_8 instances below are identical (both
        // 128x128x64 Gemm0 tiles) unlike every other wave group, whose second
        // instance halves LPerBlock. Possibly a copy/paste duplicate of the
        // intended 128x64x64 variant — confirm with the tuning results.
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 256,
            // Gemm 0
            128, 128, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 128, 1, 2>, 8, MaskingSpec>,
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
            NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
            ADataType, B0DataType, B1DataType, CDataType,
            Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
            1, 256,
            // Gemm 0
            128, 128, 64, 8, 8,
            // Gemm 1
            64, 64, 8, 16, 16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 128, 1, 2>, 8, MaskingSpec>
#endif
        >;
// clang-format on
// Host-side reference pipeline used to verify the fused device kernel:
// Gemm0 -> Softmax -> Gemm1, run stage by stage on the CPU.

// Ref Gemm0: fp16 in, fp32 out (includes the Scale op on the accumulator)
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                Acc0DataType,
                                                                                Acc1DataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                Acc1DataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;
#include "run_self_attention.inc"
/// Entry point: delegates straight to run(), which is provided by the
/// included run_self_attention.inc and owns argument parsing and execution.
int main(int argc, char* argv[])
{
    const int exit_code = run(argc, argv);
    return exit_code;
}
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
823c8801
...
@@ -5,7 +5,11 @@
...
@@ -5,7 +5,11 @@
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
@@ -22,6 +26,417 @@ namespace ck {
...
@@ -22,6 +26,417 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// Grid kernel: batched GEMM -> scaled softmax -> GEMM with fused permute.
// Computes C = Softmax(alpha * A * B0) * B1 per (G0, G1) batch slice. Tensor
// descriptors are rebuilt on-device from the scalar problem sizes, so one
// kernel binary serves both layout families: input_permute/output_permute
// select [G0, Dim, G1, Hidden] (head-interleaved) vs [G0, G1, Dim, Hidden].
// Only compiled for gfx11 (WMMA) targets; elsewhere the body is stubbed out.
template <typename DeviceOp,
          typename GridwiseOp,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
                                                       const B0DataType* __restrict__ p_b0_grid,
                                                       const B1DataType* __restrict__ p_b1_grid,
                                                       CDataType* __restrict__ p_c_grid,
                                                       index_t M,
                                                       index_t N,
                                                       index_t K,
                                                       index_t O,
                                                       index_t G0,
                                                       index_t G1,
                                                       float alpha,
                                                       bool input_permute,
                                                       bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
    defined(__gfx1102__))
    // clang-format off
    // ***************************************************
    // Make Tensor Descriptors
    constexpr index_t array_size = 4;
    std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
            : std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

    std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
            : std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

    std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
            : std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

    std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
    std::array<ck::index_t, array_size> c_gs_ms_os_strides =
        output_permute
            ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
            : std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

    const auto a_element_op    = AElementwiseOperation{};
    const auto b0_element_op   = B0ElementwiseOperation{};
    const auto acc0_element_op = AccElementwiseOperation{alpha}; // softmax scale factor
    const auto b1_element_op   = B1ElementwiseOperation{};
    const auto c_element_op    = CElementwiseOperation{};
    // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
    const auto a_grid_desc  = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    const auto b0_grid_desc =
        DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    const auto b1_grid_desc =
        DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    const auto c_grid_desc_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
    const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
    // Per-batch (G-dim) descriptors used only to compute base pointer offsets.
    const auto a_grid_desc_g_m_k =
        DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    const auto b0_grid_desc_g_l_k =
        DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    const auto b1_grid_desc_g_n_l =
        DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    const auto c_grid_desc_g_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto compute_base_ptr_of_batch = typename DeviceOp::ComputeBasePtrOfStridedBatch{
        a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
    index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
    const auto c0_matrix_mask =
        typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
    // clang-format on

    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    // Map the 1-D grid onto (batch, tile): blocks are evenly split per batch,
    // readfirstlane keeps the values uniform across the wave.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b0_grid + b0_batch_offset,
                                                p_b1_grid + b1_batch_offset,
                                                p_c_grid + c_batch_offset,
                                                p_shared,
                                                a_grid_desc,
                                                b0_grid_desc,
                                                b1_grid_desc,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b0_element_op,
                                                acc0_element_op,
                                                b1_element_op,
                                                c_element_op,
                                                c0_matrix_mask,
                                                block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b0_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = M;
    ignore = N;
    ignore = K;
    ignore = O;
    ignore = G0;
    ignore = G1;
    // fix: alpha was previously not ignored here, producing unused-parameter
    // diagnostics when compiling for non-gfx11 targets (the sibling attention
    // kernels already ignore it).
    ignore = alpha;
    ignore = input_permute;
    ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
}
// Self-Attention
// Forward self-attention kernel for a packed-QKV buffer:
//   p_qkv_grid layout: [batchSize, sequenceLength, headCount, 3, headSize],
// i.e. Q, K and V for the same (batch, seq, head) are adjacent, head_size
// elements apart. Computes Out = Softmax(alpha * Q * K^T) * V per
// (batch, head) slice. Only compiled for gfx11 (WMMA) targets.
template <typename DeviceOp,
          typename GridwiseOp,
          typename QKVDataType,
          typename ODataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
                                           ODataType* __restrict__ p_out_grid,
                                           index_t batch_size,
                                           index_t sequence_length,
                                           index_t head_count,
                                           index_t head_size,
                                           float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
    defined(__gfx1102__))
    // clang-format off
    // ***************************************************
    // Make Tensor Descriptors
    // o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
    constexpr index_t array_size = 4;
    // Q and K share the same lengths/strides; the factor 3 in the strides
    // skips over the interleaved Q/K/V slots of the packed buffer.
    std::array<ck::index_t, array_size> qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size};
    std::array<ck::index_t, array_size> qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1};
    // V is consumed transposed: [head_size (O), sequence_length (N)].
    std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length};
    std::array<ck::index_t, array_size> v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size};
    // Output: [batchSize, sequenceLength, headCount, headSize] (unpacked).
    std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size};
    std::array<ck::index_t, array_size> c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1};

    const auto a_element_op    = AElementwiseOperation{};
    const auto b0_element_op   = B0ElementwiseOperation{};
    const auto acc0_element_op = AccElementwiseOperation{alpha}; // softmax scale factor
    const auto b1_element_op   = B1ElementwiseOperation{};
    const auto c_element_op    = CElementwiseOperation{};

    const auto a_grid_desc  = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
    const auto b0_grid_desc = DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
    const auto b1_grid_desc = DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
    const auto c_grid_desc_m_n = DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock = GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
    const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
    // Per-batch (G-dim) descriptors used only for base pointer offsets.
    const auto a_grid_desc_g_m_k = DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
    const auto b0_grid_desc_g_l_k = DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
    const auto b1_grid_desc_g_n_l = DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
    const auto c_grid_desc_g_m_n = DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto compute_base_ptr_of_batch = typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
    index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
    const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
    // clang-format on

    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    // Map the 1-D grid onto (batch*head, tile); readfirstlane keeps the
    // values wave-uniform.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    // Distance between the Q, K and V slots inside the packed QKV buffer.
    const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size);
#ifdef CK_SELF_ATTN_DEBUG
    if(get_thread_global_1d_id() == 0)
    {
        printf("batch_size: %d\n", batch_size);
        printf("sequence_length: %d\n", sequence_length);
        printf("head_count: %d\n", head_count);
        printf("head_size: %d\n", head_size);
        printf("qkv_gap: %d\n", qkv_gap);
        printf("get_grid_size(): %d\n", get_grid_size());
        printf("batch_count: %d\n", batch_count);
        printf("blockid: %d\n", get_block_1d_id());
        printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
        printf("g_idx: %d\n", g_idx);
        printf("a_batch_offset: %ld\n", a_batch_offset);
        printf("b0_batch_offset: %ld\n", b0_batch_offset);
        printf("b1_batch_offset: %ld\n", b1_batch_offset);
    }
#endif
    // Q at slot 0, K at slot 1, V at slot 2 of the packed buffer.
    GridwiseOp::template Run<HasMainKBlockLoop>(p_qkv_grid + 0 * qkv_gap + a_batch_offset,
                                                p_qkv_grid + 1 * qkv_gap + b0_batch_offset,
                                                p_qkv_grid + 2 * qkv_gap + b1_batch_offset,
                                                p_out_grid + c_batch_offset,
                                                p_shared,
                                                a_grid_desc,
                                                b0_grid_desc,
                                                b1_grid_desc,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b0_element_op,
                                                acc0_element_op,
                                                b1_element_op,
                                                c_element_op,
                                                c0_matrix_mask,
                                                block_2_ctile_map);
#else
    ignore = p_qkv_grid;
    ignore = p_out_grid;
    ignore = batch_size;
    ignore = sequence_length;
    ignore = head_count;
    ignore = head_size;
    ignore = alpha;
#endif // end of if (defined(__gfx1100__))
}
// Cross-Attention
// Forward cross-attention kernel for a separate Q buffer and a packed-KV
// buffer:
//   Q : [batchSize, qSequenceLength, headCount, headSize]
//   KV: [batchSize, kvSequenceLength, headCount, 2, headSize]
// (K and V for the same (batch, seq, head) are adjacent, head_size apart).
// Computes Out = Softmax(alpha * Q * K^T) * V per (batch, head) slice.
// Only compiled for gfx11 (WMMA) targets.
template <typename DeviceOp,
          typename GridwiseOp,
          typename QDataType,
          typename KVDataType,
          typename ODataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
                                            const KVDataType* __restrict__ p_kv_grid,
                                            ODataType* __restrict__ p_out_grid,
                                            index_t batch_size,
                                            index_t q_sequence_length,
                                            index_t kv_sequence_length,
                                            index_t head_count,
                                            index_t head_size,
                                            float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
    defined(__gfx1102__))
    // clang-format off
    // ***************************************************
    // Make Tensor Descriptors
    // o Cross-attention: separate Q [batchSize, qSeqLen, headCount, headSize]
    //   and packed KV [batchSize, kvSeqLen, headCount, 2, headSize]
    constexpr index_t array_size = 4;
    // Q is not packed: plain [batch, head, qSeq, headSize] view.
    std::array<ck::index_t, array_size> q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size};
    std::array<ck::index_t, array_size> q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
    // K strides carry the factor 2 that skips the interleaved K/V slots.
    std::array<ck::index_t, array_size> k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size};
    std::array<ck::index_t, array_size> k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1};
    // V is consumed transposed: [head_size (O), kv_sequence_length (N)].
    std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length};
    std::array<ck::index_t, array_size> v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size};
    // Output: [batchSize, qSequenceLength, headCount, headSize].
    std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size};
    std::array<ck::index_t, array_size> c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};

    const auto a_element_op    = AElementwiseOperation{};
    const auto b0_element_op   = B0ElementwiseOperation{};
    const auto acc0_element_op = AccElementwiseOperation{alpha}; // softmax scale factor
    const auto b1_element_op   = B1ElementwiseOperation{};
    const auto c_element_op    = CElementwiseOperation{};

    const auto a_grid_desc  = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
    const auto b0_grid_desc = DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
    const auto b1_grid_desc = DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
    const auto c_grid_desc_m_n = DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock = GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
    const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
    // Per-batch (G-dim) descriptors used only for base pointer offsets.
    const auto a_grid_desc_g_m_k = DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
    const auto b0_grid_desc_g_l_k = DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
    const auto b1_grid_desc_g_n_l = DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
    const auto c_grid_desc_g_m_n = DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto compute_base_ptr_of_batch = typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
    index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
    const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
    // clang-format on

    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    // Map the 1-D grid onto (batch*head, tile); readfirstlane keeps the
    // values wave-uniform.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    // Distance between the K and V slots inside the packed KV buffer.
    const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size);
#ifdef CK_SELF_ATTN_DEBUG
    if(get_thread_global_1d_id() == 0)
    {
        printf("batch_size: %d\n", batch_size);
        printf("q_sequence_length: %d\n", q_sequence_length);
        printf("k_sequence_length: %d\n", kv_sequence_length);
        printf("head_count: %d\n", head_count);
        printf("head_size: %d\n", head_size);
        printf("kv_gap: %d\n", kv_gap);
        printf("get_grid_size(): %d\n", get_grid_size());
        printf("batch_count: %d\n", batch_count);
        printf("blockid: %d\n", get_block_1d_id());
        printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
        printf("g_idx: %d\n", g_idx);
        printf("a_batch_offset: %ld\n", a_batch_offset);
        printf("b0_batch_offset: %ld\n", b0_batch_offset);
        printf("b1_batch_offset: %ld\n", b1_batch_offset);
    }
#endif
    // K at slot 0, V at slot 1 of the packed KV buffer.
    GridwiseOp::template Run<HasMainKBlockLoop>(p_q_grid + a_batch_offset,
                                                p_kv_grid + 0 * kv_gap + b0_batch_offset,
                                                p_kv_grid + 1 * kv_gap + b1_batch_offset,
                                                p_out_grid + c_batch_offset,
                                                p_shared,
                                                a_grid_desc,
                                                b0_grid_desc,
                                                b1_grid_desc,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b0_element_op,
                                                acc0_element_op,
                                                b1_element_op,
                                                c_element_op,
                                                c0_matrix_mask,
                                                block_2_ctile_map);
#else
    ignore = p_q_grid;
    ignore = p_kv_grid;
    ignore = p_out_grid;
    ignore = batch_size;
    ignore = q_sequence_length;
    ignore = kv_sequence_length;
    ignore = head_count;
    ignore = head_size;
    ignore = alpha;
#endif // end of if (defined(__gfx1100__))
}
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
// ^^^^^^ (Acc0)
...
@@ -55,7 +470,8 @@ template <index_t NumDimG,
...
@@ -55,7 +470,8 @@ template <index_t NumDimG,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
LTilePerBlock
,
ck
::
index_t
LTilePerBlock
,
ck
::
index_t
L1
,
ck
::
index_t
L1
,
...
@@ -165,14 +581,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -165,14 +581,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
static
auto
MakeAGridDescriptor
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
__host__
__device__
static
auto
MakeAGridDescriptor
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
{
if
constexpr
(
AEnableLds
)
if
constexpr
(
AEnableLds
)
{
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
K1
>
{});
Number
<
A
K1
>
{});
}
}
else
else
{
{
...
@@ -184,19 +601,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -184,19 +601,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number
<
MRepeat
>
{},
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{},
Number
<
MPerWmma
>
{},
Number
<
K1
>
{});
Number
<
A
K1
>
{});
}
}
}
}
static
auto
MakeB0GridDescriptor
(
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths_vec
,
__host__
__device__
static
auto
MakeB0GridDescriptor
(
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_strides_vec
)
{
{
if
constexpr
(
B0EnableLds
)
if
constexpr
(
B0EnableLds
)
{
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
b0_gs_ls_ks_strides_vec
),
Number
<
K1
>
{});
Number
<
B
K1
>
{});
}
}
else
else
{
{
...
@@ -208,12 +626,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -208,12 +626,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number
<
LRepeat
>
{},
Number
<
LRepeat
>
{},
Number
<
LWaves
>
{},
Number
<
LWaves
>
{},
Number
<
LPerWmma
>
{},
Number
<
LPerWmma
>
{},
Number
<
K1
>
{});
Number
<
B
K1
>
{});
}
}
}
}
static
auto
MakeB1GridDescriptor
(
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths_vec
,
__host__
__device__
static
auto
MakeB1GridDescriptor
(
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_strides_vec
)
{
{
if
constexpr
(
B1EnableLds
)
if
constexpr
(
B1EnableLds
)
{
{
...
@@ -245,7 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -245,7 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
using
B1GridDesc_G_N_L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
__host__
__device__
constexpr
static
auto
make_MaskOutPredicate
()
{
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
{
...
@@ -260,7 +679,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -260,7 +679,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
struct
ComputeBasePtrOfStridedBatch
struct
ComputeBasePtrOfStridedBatch
{
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
__host__
__device__
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
B0GridDesc_G_L_K
&
b0_grid_desc_g_l_k
,
const
B0GridDesc_G_L_K
&
b0_grid_desc_g_l_k
,
const
B1GridDesc_G_N_L
&
b1_grid_desc_g_n_l
,
const
B1GridDesc_G_N_L
&
b1_grid_desc_g_n_l
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
...
@@ -324,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -324,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
MPerBlock
,
MPerBlock
,
LPerBlock
,
LPerBlock
,
KPerBlock
,
KPerBlock
,
K1
,
AK1
,
BK1
,
NPerBlock
,
NPerBlock
,
LTilePerBlock
,
LTilePerBlock
,
L1
,
L1
,
...
@@ -373,6 +793,323 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -373,6 +793,323 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
LoopSched
,
LoopSched
,
PipelineVer
>
;
PipelineVer
>
;
// Argument for the "raw" batched gemm + softmax + gemm (attention) path:
// the problem is described by plain scalar sizes instead of per-tensor
// length/stride vectors. Tensor layouts are derived from these sizes plus
// the two permute flags (see IsSupportedArgument for the exact mapping).
//
//   M  : rows of A and C            (gemm0 M / gemm1 M)
//   N  : gemm0 free dim             (rows of B0, reduction dim of gemm1)
//   K  : gemm0 reduction dim        (columns of A and B0)
//   O  : gemm1 free dim             (columns of B1 and C)
//   G0 : outer batch dim, G1 : inner batch dim (presumably batch and head
//        count of an attention problem — TODO confirm against callers)
//   alpha : scale applied between gemm0 and softmax (assumed; verify with
//           the acc elementwise op wiring)
//   input_permute/output_permute : select [G0, X, G1, Y] vs [G0, G1, X, Y]
//        memory layouts for the inputs / the output respectively.
struct RawArg : public BaseArgument
{
    RawArg(const ADataType* p_a_grid,
           const B0DataType* p_b0_grid,
           const B1DataType* p_b1_grid,
           CDataType* p_c_grid,
           index_t M,
           index_t N,
           index_t K,
           index_t O,
           index_t G0,
           index_t G1,
           float alpha,
           bool input_permute,
           bool output_permute)
        : p_a_grid_{p_a_grid},
          p_b0_grid_{p_b0_grid},
          p_b1_grid_{p_b1_grid},
          p_c_grid_{p_c_grid},
          M_{M},
          N_{N},
          K_{K},
          O_{O},
          G0_{G0},
          G1_{G1},
          alpha_{alpha},
          input_permute_{input_permute},
          output_permute_{output_permute}
    {
    }
    // Pointers
    const ADataType* p_a_grid_;
    const B0DataType* p_b0_grid_;
    const B1DataType* p_b1_grid_;
    CDataType* p_c_grid_;
    // Raw Problem Size
    index_t M_;
    index_t N_;
    index_t K_;
    index_t O_;
    index_t G0_;
    index_t G1_;
    float alpha_;
    bool input_permute_;
    bool output_permute_;
};
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
B0DataType
*
p_b0
,
const
B1DataType
*
p_b1
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
O
,
index_t
G0
,
index_t
G1
,
float
alpha
,
bool
input_permute
,
bool
output_permute
)
{
return
RawArg
{
p_a
,
p_b0
,
p_b1
,
p_c
,
M
,
N
,
K
,
O
,
G0
,
G1
,
alpha
,
input_permute
,
output_permute
};
}
// Validates a RawArg against this device op's compile-time configuration:
// (1) target architecture and accumulator types, (2) gridwise-op tile
// validity, (3) batch-count consistency, (4) vector load/store alignment
// of the lowest dimensions. Returns false (after printing a short reason)
// when the problem cannot run with this instantiation.
static bool IsSupportedArgument(const RawArg& arg)
{
    // WMMA path requires an RDNA3 (gfx110x) target.
    if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
       ck::get_device_name() == "gfx1102")
    {
        // WMMA accumulators are fp32 or int32 only.
        if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
        {
            printf("DeviceOp: Acc0 Type err");
            return false;
        }

        if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
        {
            printf("DeviceOp: Acc1 Type err");
            return false;
        }
    }
    else
    {
        printf("DeviceOp: Arch err");
        return false;
    }

    // Raw problems are always rank-4: [G0, G1, dim0, dim1].
    constexpr index_t array_size = 4;
    ck::index_t G0              = arg.G0_;
    ck::index_t G1              = arg.G1_;
    ck::index_t M               = arg.M_;
    ck::index_t N               = arg.N_;
    ck::index_t K               = arg.K_;
    ck::index_t O               = arg.O_;
    bool input_permute          = arg.input_permute_;
    bool output_permute         = arg.output_permute_;

    // Rebuild the length/stride arrays exactly as the launch path does, so
    // the checks below see the same descriptors the kernel will use.
    std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
            : std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

    std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
            : std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

    std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
        input_permute
            ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
            : std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

    std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
    std::array<ck::index_t, array_size> c_gs_ms_os_strides =
        output_permute
            ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
            : std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

    const auto a_grid_desc  = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    const auto b0_grid_desc = DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    const auto b1_grid_desc = DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    const auto c_grid_desc_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
    const auto c_grid_desc_g_m_n =
        DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});

    // Tile/shape validity as judged by the gridwise op itself.
    if(!GridwiseOp::CheckValidity(
           a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
    {
        return false;
    }

    // Check if C permute dimension matches GEMM + GEMM shape
    const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
    if(!(c_g == batch_count))
    {
        printf("DeviceOp: BatchCount err");
        return false;
    }

    // Note: we need raw lengths since threadwise copy can not handle vector load when part of
    // vector is out of bounds
    // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
    const auto MzRaw = M;
    const auto LzRaw = N;
    const auto KzRaw = K;
    const auto NzRaw = O;

    // Check scalar per vector requirement
    const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
    const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
    const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
    const auto c_extent_lowest  = NzRaw;

    if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
         b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
         b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
         c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
    {
        printf("DeviceOp: Data Transfer Vector scalar err");
        return false;
    }

    // NOTE(review): these arrays are declared with NumDimG + NumDimM + NumDimN
    // elements but only two entries are initialized and only [0]/[1] are read
    // below — a size-2 array (or plain pair) seems intended; confirm.
    std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
        a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
        a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
    std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
        b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
        b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
    std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
        b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
        b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
    std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
        c_gs_ms_os_strides[NumDimG + NumDimM - 1],
        c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};

    // Check vector load/store requirement
    const auto a_stride_lowest =
        ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
    const auto b0_stride_lowest =
        B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
    const auto b1_stride_lowest =
        B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
    const auto c_stride_lowest = c_mz_nz_strides_[1];

    // NOTE(review): `||` accepts the argument when ANY tensor has unit stride
    // in its vectorized dimension; a per-tensor requirement would need `&&`.
    // Verify intent before changing — upstream may rely on this behavior.
    if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
         c_stride_lowest == 1))
    {
        printf("DeviceOp: Data Vectorize transfer err");
        return false;
    }

    return true;
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
RawArg
*>
(
p_arg
));
}
// Argument for the fused self-attention forward path: Q, K and V are packed
// in a single contiguous grid (p_qkv_grid) and attention is computed over
// one shared sequence length. alpha is the scaling factor forwarded to the
// kernel (presumably the softmax scale — confirm against the kernel).
struct SelfAttnArg : public BaseArgument
{
    SelfAttnArg(const ADataType* p_qkv_grid,
                CDataType* p_out_grid,
                index_t batch_size,
                index_t sequence_length,
                index_t head_count,
                index_t head_size,
                float alpha)
        : p_qkv_grid_{p_qkv_grid},
          p_out_grid_{p_out_grid},
          batch_size_{batch_size},
          sequence_length_{sequence_length},
          head_count_{head_count},
          head_size_{head_size},
          alpha_{alpha}
    {
    }
    // Pointers
    const ADataType* p_qkv_grid_;
    CDataType* p_out_grid_;
    // Raw Problem Size
    index_t batch_size_;
    index_t sequence_length_;
    index_t head_count_;
    index_t head_size_;
    float alpha_;
};
static
auto
MakeSelfAttnArgument
(
const
ADataType
*
p_qkv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
return
SelfAttnArg
{
p_qkv_grid
,
p_out_grid
,
batch_size
,
sequence_length
,
head_count
,
head_size
,
alpha
};
}
// Argument for the fused cross-attention forward path: Q comes from its own
// grid (length q_sequence_length), while K and V are packed together in
// p_kv_grid (length kv_sequence_length). alpha is the scaling factor
// forwarded to the kernel (presumably the softmax scale — confirm).
struct CrossAttnArg : public BaseArgument
{
    CrossAttnArg(const ADataType* p_q_grid,
                 const B0DataType* p_kv_grid,
                 CDataType* p_out_grid,
                 index_t batch_size,
                 index_t q_sequence_length,
                 index_t kv_sequence_length,
                 index_t head_count,
                 index_t head_size,
                 float alpha)
        : p_q_grid_{p_q_grid},
          p_kv_grid_{p_kv_grid},
          p_out_grid_{p_out_grid},
          batch_size_{batch_size},
          q_sequence_length_{q_sequence_length},
          kv_sequence_length_{kv_sequence_length},
          head_count_{head_count},
          head_size_{head_size},
          alpha_{alpha}
    {
    }
    // Pointers
    const ADataType* p_q_grid_;
    const B0DataType* p_kv_grid_;
    CDataType* p_out_grid_;
    // Raw Problem Size
    index_t batch_size_;
    index_t q_sequence_length_;
    index_t kv_sequence_length_;
    index_t head_count_;
    index_t head_size_;
    float alpha_;
};
static
auto
MakeCrossAttnArgument
(
const
ADataType
*
p_q_grid
,
const
B0DataType
*
p_kv_grid
,
CDataType
*
p_out_grid
,
index_t
batch_size
,
index_t
q_sequence_length
,
index_t
kv_sequence_length
,
index_t
head_count
,
index_t
head_size
,
float
alpha
)
{
return
CrossAttnArg
{
p_q_grid
,
p_kv_grid
,
p_out_grid
,
batch_size
,
q_sequence_length
,
kv_sequence_length
,
head_count
,
head_size
,
alpha
};
}
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -383,14 +1120,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -383,14 +1120,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_lengths
,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ls_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_ns_ls_strides
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ls_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_ns_lengths
,
...
@@ -497,11 +1234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -497,11 +1234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Strides for the last M/N/K dimensions of A/B0/B1/C
// Strides for the last M/N/K dimensions of A/B0/B1/C
// for sanity check of vector load/store
// for sanity check of vector load/store
std
::
vector
<
index_t
>
raw_lengths_mz_lz_kz_nz_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
raw_lengths_mz_lz_kz_nz_
;
std
::
vector
<
index_t
>
a_mz_kz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_mz_kz_strides_
;
std
::
vector
<
index_t
>
b0_lz_kz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_lz_kz_strides_
;
std
::
vector
<
index_t
>
b1_nz_lz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_nz_lz_strides_
;
std
::
vector
<
index_t
>
c_mz_nz_strides_
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_mz_nz_strides_
;
index_t
batch_count_
;
index_t
batch_count_
;
// Batch Offset
// Batch Offset
...
@@ -509,46 +1246,151 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -509,46 +1246,151 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
};
};
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
SelfAttn
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
SelfAttnArg
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
const
index_t
grid_size
=
const
auto
M0
=
math
::
integer_divide_ceil
(
arg
.
sequence_length_
,
MPerBlock
);
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
const
auto
N0
=
math
::
integer_divide_ceil
(
arg
.
head_size_
,
NPerBlock
)
;
const
auto
K
=
[
&
]()
{
const
index_t
grid_size
=
arg
.
batch_size_
*
arg
.
head_count_
*
M0
*
N0
;
if
constexpr
(
AEnableLds
)
const
auto
K
=
arg
.
head_size_
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_wmma_self_attention_forward
<
DeviceOp
,
GridwiseOp
,
ADataType
,
CDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_qkv_grid_
,
arg
.
p_out_grid_
,
arg
.
batch_size_
,
arg
.
sequence_length_
,
arg
.
head_count_
,
arg
.
head_size_
,
arg
.
alpha_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I2
);
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{}
);
}
}
else
else
{
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
arg
.
a_grid_desc
.
GetLength
(
I4
)
*
arg
.
a_grid_desc
.
GetLength
(
I6
);
}
}
}
}();
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
// Creates a default-constructed invoker for the self-attention path.
static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; }
// Invoker
// Invoker for the fused cross-attention forward kernel: separate Q grid plus
// a packed KV grid. Launches kernel_wmma_cross_attention_forward with one
// workgroup per (batch, head, M-tile, N-tile) of the output.
struct CrossAttnInvoker : public BaseInvoker
{
    using Argument = DeviceOp::CrossAttnArg;

    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        // Grid covers every output tile of every (batch, head) pair; output
        // rows follow the query sequence length.
        const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
        const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;

        // Reduction length of the first gemm is the per-head feature size.
        const auto K = arg.head_size_;

        auto launch_kernel = [&](auto has_main_k_block_loop) {
            const auto kernel = kernel_wmma_cross_attention_forward<DeviceOp,
                                                                    GridwiseOp,
                                                                    ADataType,
                                                                    B0DataType,
                                                                    CDataType,
                                                                    AElementwiseOperation,
                                                                    B0ElementwiseOperation,
                                                                    AccElementwiseOperation,
                                                                    B1ElementwiseOperation,
                                                                    CElementwiseOperation,
                                                                    has_main_k_block_loop>;
            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.p_q_grid_,
                                          arg.p_kv_grid_,
                                          arg.p_out_grid_,
                                          arg.batch_size_,
                                          arg.q_sequence_length_,
                                          arg.kv_sequence_length_,
                                          arg.head_count_,
                                          arg.head_size_,
                                          arg.alpha_);
        };

        // Select the kernel instantiation matching the K-loop structure.
        if(GridwiseOp::CalculateHasMainKBlockLoop(K))
        {
            return launch_kernel(integral_constant<bool, true>{});
        }
        else
        {
            return launch_kernel(integral_constant<bool, false>{});
        }
    }

    // polymorphic
    // NOTE(review): dynamic_cast result is dereferenced unchecked; passing a
    // non-CrossAttnArg here is UB — consider a null check.
    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    }
};
// Creates a default-constructed invoker for the cross-attention path.
static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; }
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
RawArg
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
M0
=
math
::
integer_divide_ceil
(
arg
.
M_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
arg
.
O_
,
NPerBlock
);
const
index_t
grid_size
=
arg
.
G0_
*
arg
.
G1_
*
M0
*
N0
;
const
auto
K
=
arg
.
K_
;
// printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
<
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
<
DeviceOp
,
GridwiseOp
,
GridwiseOp
,
ADataType
,
ADataType
,
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
DeviceOp
::
AGridDesc
,
DeviceOp
::
B0GridDesc
,
DeviceOp
::
B1GridDesc
,
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
AElementwiseOperation
,
B0ElementwiseOperation
,
B0ElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
typename
GridwiseOp
::
DefaultBlock2CTileMap
,
has_main_k_block_loop
>
;
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -560,19 +1402,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -560,19 +1402,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg
.
p_b0_grid_
,
arg
.
p_b0_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc
,
arg
.
M_
,
arg
.
b0_grid_desc
,
arg
.
N_
,
arg
.
b1_grid_desc
,
arg
.
K_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
O_
,
arg
.
a_element_op_
,
arg
.
G0_
,
arg
.
b0_element_op_
,
arg
.
G1_
,
arg
.
acc_element_op_
,
arg
.
alpha_
,
arg
.
b1_element_op_
,
arg
.
input_permute_
,
arg
.
c_element_op_
,
arg
.
output_permute_
);
arg
.
batch_count_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
block_2_ctile_map_
);
};
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
...
@@ -598,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -598,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// TODO: properly implement this check
// TODO: properly implement this check
return
true
;
return
true
;
}
}
#if 0
static bool IsSupportedArgument(const Argument& arg)
static bool IsSupportedArgument(const Argument& arg)
{
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
...
@@ -695,14 +1533,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -695,14 +1533,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& a_gs_ms_ks_lengths,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& a_gs_ms_ks_strides,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& b0_gs_ls_ks_lengths,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_strides
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& b0_gs_ls_ks_strides,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& b1_gs_ns_ls_lengths,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& b1_gs_ns_ls_strides,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_lengths
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& c_gs_ms_ns_lengths,
const
std
::
vector
<
index_t
>&
c_gs_ms_ns_strides
,
const std::
array
<index_t
, NumDimG + NumDimM + NumDimN
>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
...
@@ -739,6 +1577,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -739,6 +1577,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
b1_element_op,
b1_element_op,
c_element_op};
c_element_op};
}
}
#endif
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
@@ -766,20 +1605,60 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -766,20 +1605,60 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
CElementwiseOperation
c_element_op
)
override
{
{
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
a_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b0_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
b1_strides
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_lengths
;
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>
c_strides
;
std
::
transform
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
end
(),
a_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
end
(),
a_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b0_gs_ls_ks_lengths
.
begin
(),
b0_gs_ls_ks_lengths
.
end
(),
b0_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b0_gs_ls_ks_strides
.
begin
(),
b0_gs_ls_ks_strides
.
end
(),
b0_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b1_gs_ns_ls_lengths
.
begin
(),
b1_gs_ns_ls_lengths
.
end
(),
b1_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
b1_gs_ns_ls_strides
.
begin
(),
b1_gs_ns_ls_strides
.
end
(),
b1_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
c_gs_ms_ns_lengths
.
begin
(),
c_gs_ms_ns_lengths
.
end
(),
c_lengths
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
std
::
transform
(
c_gs_ms_ns_strides
.
begin
(),
c_gs_ms_ns_strides
.
end
(),
c_strides
.
begin
(),
[](
index_t
i
)
{
return
i
;
});
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
B0DataType
*>
(
p_b0
),
static_cast
<
const
B0DataType
*>
(
p_b0
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
p_acc0_biases
,
p_acc0_biases
,
p_acc1_biases
,
p_acc1_biases
,
a_
gs_ms_ks_
lengths
,
a_lengths
,
a_
gs_ms_ks_
strides
,
a_strides
,
b0_
gs_ls_ks_
lengths
,
b0_lengths
,
b0_
gs_ls_ks_
strides
,
b0_strides
,
b1_
gs_ns_ls_
lengths
,
b1_lengths
,
b1_
gs_ns_ls_
strides
,
b1_strides
,
c_
gs_ms_ns_
lengths
,
c_lengths
,
c_
gs_ms_ns_
strides
,
c_strides
,
acc0_biases_gs_ms_ls_lengths
,
acc0_biases_gs_ms_ls_lengths
,
acc0_biases_gs_ms_ls_strides
,
acc0_biases_gs_ms_ls_strides
,
acc1_biases_gs_ms_ns_lengths
,
acc1_biases_gs_ms_ns_lengths
,
...
@@ -819,11 +1698,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -819,11 +1698,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
LTilePerBlock
<<
", "
<<
LTilePerBlock
<<
", "
<<
L1
<<
L1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"B0Spec"
<<
getTensorSpecializationString
(
B0Spec
)
<<
", "
<<
"B0Spec"
<<
getTensorSpecializationString
(
B0Spec
)
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
823c8801
...
@@ -650,7 +650,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -650,7 +650,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
2
];
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
...
@@ -667,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -667,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
2
];
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
3
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @
823c8801
...
@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
...
@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template
<
typename
MaskOutPredicate
>
template
<
typename
MaskOutPredicate
>
struct
C0MatrixMask_impl
struct
C0MatrixMask_impl
{
{
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{}
__host__
__device__
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{
}
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
823c8801
...
@@ -18,102 +18,6 @@
...
@@ -18,102 +18,6 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseOp
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
AGridDesc
,
typename
B0GridDesc
,
typename
B1GridDesc
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
B0DataType
*
__restrict__
p_b0_grid
,
const
B1DataType
*
__restrict__
p_b1_grid
,
CDataType
*
__restrict__
p_c_grid
,
const
AGridDesc
a_grid_desc
,
const
B0GridDesc
b0_grid_desc
,
const
B1GridDesc
b1_grid_desc
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
B0ElementwiseOperation
b0_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB0BasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b0_grid
+
b0_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b0_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
c0_matrix_mask
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b0_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc
;
ignore
=
b0_grid_desc
;
ignore
=
b1_grid_desc
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b0_element_op
;
ignore
=
acc_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx1100__))
}
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -136,7 +40,8 @@ template <typename ADataType,
...
@@ -136,7 +40,8 @@ template <typename ADataType,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
LPerBlock
,
index_t
LPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
K1Value
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
LTilePerBlock
,
index_t
LTilePerBlock
,
index_t
L1Value
,
index_t
L1Value
,
...
@@ -194,9 +99,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -194,9 +99,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
AK1
=
Number
<
K1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
A
K1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
K1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
B
K1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
K1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
B
K1Value
>
{};
static
constexpr
auto
L0PerBlock
=
LTilePerBlock
/
L1Value
;
static
constexpr
auto
L0PerBlock
=
LTilePerBlock
/
L1Value
;
static
constexpr
auto
AL0
=
Number
<
L0PerBlock
/
2
>
{};
static
constexpr
auto
AL0
=
Number
<
L0PerBlock
/
2
>
{};
...
@@ -714,7 +619,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -714,7 +619,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
KPerBlock
;
const
index_t
num_loop
=
math
::
integer_divide_ceil
(
K
,
KPerBlock
)
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
}
...
@@ -887,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -887,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
A
K1Value
;
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
a_block_desc
.
GetElementSpaceSize
());
a_block_desc
.
GetElementSpaceSize
());
...
@@ -903,7 +808,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -903,7 +808,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number
<
K0PerWmma
>
{},
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Number
<
A
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
6
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
...
@@ -966,7 +871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -966,7 +871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
B
K1Value
;
auto
b0_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
B0DataType
>
(
auto
b0_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
B0DataType
>
(
b0_block_desc
.
GetElementSpaceSize
());
b0_block_desc
.
GetElementSpaceSize
());
...
@@ -982,7 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -982,7 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number
<
K0PerWmma
>
{},
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Number
<
A
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
6
,
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferSrcScalarPerVector
,
...
@@ -1009,7 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -1009,7 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
/*******************************************************************************/
/*******************************************************************************/
// Gemm0
// Gemm0
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1Value
,
WmmaK
);
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
math
::
integer_least_multiple
(
AK1Value
,
BK1Value
)
,
WmmaK
);
auto
blockwise_gemm0
=
BlockwiseGemmWMMA
<
auto
blockwise_gemm0
=
BlockwiseGemmWMMA
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
823c8801
...
@@ -16,14 +16,15 @@ template <index_t NumDimG,
...
@@ -16,14 +16,15 @@ template <index_t NumDimG,
index_t
NumDimM
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimN
,
device
::
TensorSpecialization
TensorSpec
>
device
::
TensorSpecialization
TensorSpec
>
static
auto
MakeGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
gs_ms_ns_lengths_vec
,
__host__
__device__
static
auto
const
std
::
vector
<
index_t
>&
gs_ms_ns_strides_vec
)
MakeGridDescriptorPair
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
gs_ms_ns_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
gs_ms_ns_strides_vec
)
{
{
if
(
!
(
gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
//
if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
))
//
gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
{
//
{
throw
std
::
runtime_error
(
"wrong! dimension must match input lengths"
);
//
throw std::runtime_error("wrong! dimension must match input lengths");
}
//
}
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
...
@@ -143,21 +144,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -143,21 +144,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
//
// A
// A
//
//
static
auto
MakeAGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
__host__
__device__
static
auto
MakeAGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimK
,
ASpec
>
(
a_gs_ms_ks_lengths_vec
,
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimK
,
ASpec
>
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
);
a_gs_ms_ks_strides_vec
);
}
}
// TODO: rename to G_MRaw_KRaw
// TODO: rename to G_MRaw_KRaw
static
auto
MakeAGridDescriptor_G_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
__host__
__device__
static
auto
MakeAGridDescriptor_G_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
{
return
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
first
;
return
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
first
;
}
}
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
__host__
__device__
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
{
return
matrix_padder
.
PadADescriptor_M_K
(
return
matrix_padder
.
PadADescriptor_M_K
(
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
second
);
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
second
);
...
@@ -212,21 +216,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -212,21 +216,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
//
// B (alias of B0)
// B (alias of B0)
//
//
static
auto
MakeB0GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
__host__
__device__
static
auto
MakeB0GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_strides_vec
)
{
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimK
,
B0Spec
>
(
b0_gs_ns_ks_lengths_vec
,
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimK
,
B0Spec
>
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
);
b0_gs_ns_ks_strides_vec
);
}
}
// TODO: rename to G_MRaw_NRaw
// TODO: rename to G_MRaw_NRaw
static
auto
MakeB0GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
__host__
__device__
static
auto
MakeB0GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_strides_vec
)
{
{
return
MakeB0GridDescriptorPair
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
).
first
;
return
MakeB0GridDescriptorPair
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
).
first
;
}
}
static
auto
MakeB0GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
__host__
__device__
static
auto
MakeB0GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b0_gs_ns_ks_strides_vec
)
{
{
// alias of matrix_padder.PadB0Descriptor_N_K
// alias of matrix_padder.PadB0Descriptor_N_K
return
matrix_padder
.
PadBDescriptor_N_K
(
return
matrix_padder
.
PadBDescriptor_N_K
(
...
@@ -282,21 +289,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -282,21 +289,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
//
// B1
// B1
//
//
static
auto
MakeB1GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
__host__
__device__
static
auto
MakeB1GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_strides_vec
)
{
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimO
,
NumDimN
,
B1Spec
>
(
b1_gs_os_ns_lengths_vec
,
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimO
,
NumDimN
,
B1Spec
>
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
);
b1_gs_os_ns_strides_vec
);
}
}
// TODO: rename to G_NRaw_KRaw
// TODO: rename to G_NRaw_KRaw
static
auto
MakeB1GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
__host__
__device__
static
auto
MakeB1GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_strides_vec
)
{
{
return
MakeB1GridDescriptorPair
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
).
first
;
return
MakeB1GridDescriptorPair
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
).
first
;
}
}
static
auto
MakeB1GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
__host__
__device__
static
auto
MakeB1GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_os_ns_strides_vec
)
{
{
// alias of matrix_padder.PadB1Descriptor_O_N
// alias of matrix_padder.PadB1Descriptor_O_N
return
matrix_padder
.
PadB1Descriptor_N_K
(
return
matrix_padder
.
PadB1Descriptor_N_K
(
...
@@ -353,21 +363,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -353,21 +363,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
//
// C
// C
//
//
static
auto
MakeCGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
__host__
__device__
static
auto
MakeCGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_strides_vec
)
{
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimO
,
CSpec
>
(
c_gs_ms_os_lengths_vec
,
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimO
,
CSpec
>
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
);
c_gs_ms_os_strides_vec
);
}
}
// TODO: rename to G_MRaw_NRaw
// TODO: rename to G_MRaw_NRaw
static
auto
MakeCGridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
__host__
__device__
static
auto
MakeCGridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_strides_vec
)
{
{
return
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
first
;
return
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
first
;
}
}
static
auto
MakeCGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_os_strides_vec
)
{
{
return
matrix_padder
.
PadCDescriptor_M_N
(
return
matrix_padder
.
PadCDescriptor_M_N
(
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
...
...
script/unet_mha.sh
0 → 100644
View file @
823c8801
#!/bin/bash
while
getopts
e: flag
do
case
"
${
flag
}
"
in
e
)
executable
=
${
OPTARG
}
;;
esac
done
echo
"CK-NAVI31 Performance Test: MHA for AITemplate"
VERIFICATION
=
0
INITIALIZE
=
1
TIMING
=
1
ALL_TEST_CASE
=
0
SELF_ATTENTION
=
1
CROSS_ATTENTION
=
0
CAUSAL_MASK
=
0
# self attention with causal mask
if
[
$ALL_TEST_CASE
-eq
1
]
||
{
[
$SELF_ATTENTION
-eq
1
]
&&
[
$CAUSAL_MASK
-eq
1
]
;
}
;
then
echo
"Test launched: self attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
4096 4096 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
1024 1024 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
256 256 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
64 64 160 160 2 8 0.079056940972805 1 1
fi
# cross attention with causal mask
if
[
$ALL_TEST_CASE
-eq
1
]
||
{
[
$CROSS_ATTENTION
-eq
1
]
&&
[
$CAUSAL_MASK
-eq
1
]
;
}
;
then
echo
"Test launched: cross attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
4096 64 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
1024 64 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
256 64 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16
$VERIFICATION
1
$TIMING
64 64 160 160 2 8 0.079056940972805 1 1
fi
# self attention without causal mask
if
[
$ALL_TEST_CASE
-eq
1
]
||
{
[
$SELF_ATTENTION
-eq
1
]
&&
[
$CAUSAL_MASK
-eq
0
]
;
}
;
then
echo
"Test launched: self attention without causal mask"
$executable
$VERIFICATION
$INITIALIZE
$TIMING
4096 4096 64 64 2 5 0.125 1 1
$executable
$VERIFICATION
$INITIALIZE
$TIMING
1024 1024 64 64 2 10 0.125 1 1
$executable
$VERIFICATION
$INITIALIZE
$TIMING
256 256 64 64 2 20 0.125 1 1
$executable
$VERIFICATION
$INITIALIZE
$TIMING
64 64 64 64 2 20 0.125 1 1
fi
# cross attention without causal mask
if
[
$ALL_TEST_CASE
-eq
1
]
||
{
[
$CROSS_ATTENTION
-eq
1
]
&&
[
$CAUSAL_MASK
-eq
0
]
;
}
;
then
echo
"Test launched: cross attention without causal mask"
$executable
$VERIFICATION
1
$TIMING
4096 64 40 40 2 8 0.158113881945610 1 1
$executable
$VERIFICATION
1
$TIMING
1024 64 80 80 2 8 0.111803397536277 1 1
$executable
$VERIFICATION
1
$TIMING
256 64 160 160 2 8 0.079056940972805 1 1
$executable
$VERIFICATION
1
$TIMING
64 64 160 160 2 8 0.079056940972805 1 1
fi
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment