gaoqiong / composable_kernel

Commit 6dbced07, authored Sep 26, 2023 by letaoqin

    change mha infer class and file name

Parent: 63ea1d70

Showing 7 changed files with 230 additions and 227 deletions.
example/52_flash_atten_bias/CMakeLists.txt                                               +3   -3
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp              +68  -67
example/52_flash_atten_bias/batched_gemm_multihead_attention_infer.cpp                   +68  -67
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp                    +68  -67
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp    +7   -7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp    +15  -15
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp         +1   -1
example/52_flash_atten_bias/CMakeLists.txt

-add_example_executable(example_batched_multihead_attention_forward batched_gemm_multihead_attention_forward.cpp)
+add_example_executable(example_batched_multihead_attention_infer batched_gemm_multihead_attention_infer.cpp)
-add_example_executable(example_batched_multihead_attention_bias_forward batched_gemm_multihead_attention_bias_forward.cpp)
+add_example_executable(example_batched_multihead_attention_bias_infer batched_gemm_multihead_attention_bias_infer.cpp)
-add_example_executable(example_grouped_multihead_attention_bias_forward grouped_mutihead_attention_bias_forward.cpp)
+add_example_executable(example_grouped_multihead_attention_bias_infer grouped_mutihead_attention_bias_infer.cpp)
 add_example_executable(example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp)
 add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp)
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_forward.cpp → example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp

@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"

@@ -67,72 +67,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
 (the template-argument list that follows is unchanged apart from re-indentation)
         NumDimG,
         NumDimM,
         NumDimN,
         NumDimK,
         NumDimO,
         ADataType,
         B0DataType,
         B1DataType,
         CDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
         AccDataType,
         CShuffleDataType,
         AElementOp,
         B0ElementOp,
         Acc0ElementOp,
         B1ElementOp,
         CElementOp,
         GemmSpec,
         TensorSpecA,
         TensorSpecB0,
         TensorSpecB1,
         TensorSpecC,
         1,
         256,
         128,          // MPerBlock
         128,          // NPerBlock
         32,           // KPerBlock
         DIM,          // Gemm1NPerBlock
         32,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1
         2,            // B1K1
         32,           // MPerXDL
         32,           // NPerXDL
         1,            // MXdlPerWave
         4,            // NXdlPerWave
         DIM / 32,     // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         8,
         8,
         true,
         S<4, 64, 1>,  // BBlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         8,
         8,
         true,
         4,
         S<16, 16, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
         4,
         2,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
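For readers unfamiliar with the operation these examples exercise, the hunk header above spells it out: C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o. The snippet below is an illustrative host-side reference for a single G slice — it is not part of this commit and is not the library's ReferenceBatchedGemm — assuming plain row-major float buffers and no scaling or bias:

// Minimal reference of C = Softmax(A * B0) * B1 for one batch/head slice.
#include <algorithm>
#include <cmath>
#include <vector>

void reference_attention_single_slice(const std::vector<float>& a,  // M x K
                                      const std::vector<float>& b0, // K x N
                                      const std::vector<float>& b1, // N x O
                                      std::vector<float>& c,        // M x O
                                      int M, int N, int K, int O)
{
    std::vector<float> s(static_cast<size_t>(N)); // one row of S = A * B0

    for(int m = 0; m < M; ++m)
    {
        // Gemm0: S[m, :] = A[m, :] * B0
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b0[k * N + n];
            s[n] = acc;
        }

        // numerically stable row-wise softmax
        const float row_max = *std::max_element(s.begin(), s.end());
        float row_sum       = 0.f;
        for(int n = 0; n < N; ++n)
        {
            s[n] = std::exp(s[n] - row_max);
            row_sum += s[n];
        }
        for(int n = 0; n < N; ++n)
            s[n] /= row_sum;

        // Gemm1: C[m, :] = Softmax(S)[m, :] * B1
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += s[n] * b1[n * O + o];
            c[m * O + o] = acc;
        }
    }
}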
example/52_flash_atten_bias/batched_gemm_multihead_attention_forward.cpp → example/52_flash_atten_bias/batched_gemm_multihead_attention_infer.cpp

Same two hunks as in batched_gemm_multihead_attention_bias_infer.cpp above; the template-argument list of DeviceGemmInstance is identical and is not repeated here.

@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"

@@ -67,72 +67,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
         (… identical template arguments as above, ending with …)
         MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
example/52_flash_atten_bias/grouped_mutihead_attention_bias_forward.cpp → example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp

@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"

@@ -66,72 +66,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle<
 (the template-argument list that follows is unchanged apart from re-indentation)
         NumDimG,
         NumDimM,
         NumDimN,
         NumDimK,
         NumDimO,
         ADataType,
         B0DataType,
         B1DataType,
         CDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
         AccDataType,
         CShuffleDataType,
         AElementOp,
         B0ElementOp,
         Acc0ElementOp,
         B1ElementOp,
         CElementOp,
         GemmSpec,
         TensorSpecA,
         TensorSpecB0,
         TensorSpecB1,
         TensorSpecC,
         1,
         256,
         128,          // MPerBlock
         128,          // NPerBlock
         32,           // KPerBlock
         64,           // Gemm1NPerBlock
         32,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1
         2,            // B1K1
         32,           // MPerXDL
         32,           // NPerXDL
         1,            // MXdlPerWave
         4,            // NXdlPerWave
         2,            // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         8,
         8,
         true,
         S<4, 64, 1>,  // BBlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         8,
         8,
         true,
         4,
         S<16, 16, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
         4,
         2,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
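All three renamed examples drive their DeviceGemmInstance through the same host-side flow. Below is a condensed, non-authoritative sketch of that flow, assuming the surrounding example's main() has already allocated the device buffers and defined the problem sizes, strides, and element-wise operators; the long MakeArgument parameter list is device-op specific and is deliberately elided rather than guessed:

// Sketch only: runs inside the example's main(); `time_kernel` and the
// MakeArgument inputs are assumed to be defined by the surrounding code.
auto op       = DeviceGemmInstance{};
auto invoker  = op.MakeInvoker();
auto argument = op.MakeArgument(/* device pointers, lengths, strides, element-wise ops ... */);

if(!op.IsSupportedArgument(argument))
{
    std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
    return 0;
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});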
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp

@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -44,7 +44,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_multiple_head_flash_attention_forward(
+        kernel_batched_multiple_head_flash_attention_infer(
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const D0DataType* p_d0_grid,

@@ -205,7 +205,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           int D0sTransferSrcScalarPerVector = 4,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionForward_Xdl
+struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
     : public DeviceBatchedMultiheadAttentionInfer<NumDimG,
                                                   NumDimM,
                                                   NumDimN,

@@ -243,7 +243,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl;
+    using DeviceOp = DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
         D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
     };
-    using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         D0DataType,
         GemmAccDataType,

@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multiple_head_flash_attention_forward<
+            const auto kernel = kernel_batched_multiple_head_flash_attention_infer<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 D0DataType,

@@ -925,7 +925,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionForward_Xdl"
+        str << "DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
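For context, the kernel selected inside the launch_kernel lambda above is dispatched through CK's launch_and_time_kernel helper (the grouped variant below shows the same pattern). A condensed, non-authoritative sketch follows, with the remaining template and kernel arguments elided and grid_size/BlockSize assumed to come from the surrounding Invoker::Run; the helper's (stream_config, kernel, grid_dim, block_dim, lds_byte, args...) parameter order is an assumption:

// Sketch of the dispatch around the renamed kernel; not a verbatim excerpt.
const auto kernel = kernel_batched_multiple_head_flash_attention_infer<
    GridwiseGemm,
    ADataType,
    D0DataType,
    /* ...remaining template arguments... */
    has_main_k_block_loop_>;

ave_time = launch_and_time_kernel(
    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0 /* lds_byte */,
    p_a_grid, p_b_grid /* ...and the remaining kernel arguments... */);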
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp

@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_mha_infer.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -35,7 +35,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
+        kernel_grouped_multiple_head_flash_attention_infer(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,

@@ -194,7 +194,7 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl
+struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
     : public DeviceGroupedMultiheadAttentionInfer<NumDimG,
                                                   NumDimM,
                                                   NumDimN,

@@ -230,7 +230,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl;
+    using DeviceOp = DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle;
     using ProblemDesc = typename DeviceGroupedMultiheadAttentionInfer<NumDimG,
                                                                       NumDimM,
                                                                       NumDimN,

@@ -382,7 +382,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         Acc0BiasDataType,
         GemmAccDataType,

@@ -698,15 +698,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
             const auto kernel =
-                kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
+                kernel_grouped_multiple_head_flash_attention_infer<GridwiseGemm,
                                                                    D0DataType,
                                                                    GroupKernelArg,
                                                                    AElementwiseOperation,
                                                                    BElementwiseOperation,
                                                                    AccElementwiseOperation,
                                                                    B1ElementwiseOperation,
                                                                    CElementwiseOperation,
                                                                    has_main_k_block_loop_>;

             return launch_and_time_kernel(stream_config,

@@ -944,7 +944,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionForward_Xdl"
+        str << "DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp

@@ -86,7 +86,7 @@ template <typename FloatAB,
           bool PadN,
           bool MaskOutUpperTriangle,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle
+struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
 {
     static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
                       D0BlockTransferSrcScalarPerVector == 2 ||