gaoqiong / composable_kernel · Commits

Commit 99ebfeba, authored Apr 18, 2023 by danyao12

    correct deterministic mode

Parent: 84a81ae2

Showing 13 changed files with 538 additions and 401 deletions (+538 −401).
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp   +11 −7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp     +19 −11
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp   +11 −7
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp     +19 −11
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp   +76 −71
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp   +76 −71
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp       +129 −86
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp   +37 −36
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp   +37 −36
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp       +78 −32
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp       +18 −14
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp       +10 −6
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp            +17 −13
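Taken together, the hunks below implement one mechanism in three layers: the example files pin a Deterministic flag and append it as a new trailing template argument on every device-op instance; the device-side launchers shrink the grid to one workgroup per batch and pass the number of M-tiles (mblock) into the kernel; and the gridwise kernels walk those tiles serially, taking the tile index from the loop rather than from the block-to-tile map. Below is a minimal host-side C++ model of that dispatch, for orientation only; run_tile, launch, and tiles_per_batch are illustrative names, not CK identifiers.

// Host-side model of the deterministic dispatch this commit corrects.
// Illustrative names throughout; none of this is CK code.
#include <cstdio>

using index_t = int;

// Stand-in for one GridwiseGemm::Run call on one M-tile of one batch.
void run_tile(index_t batch, index_t tile) { std::printf("batch %d, tile %d\n", batch, tile); }

template <bool Deterministic>
void launch(index_t batch_count, index_t tiles_per_batch)
{
    // Grid size as in the diff: one block per batch when deterministic,
    // one block per tile otherwise.
    const index_t grid_size = (Deterministic ? 1 : tiles_per_batch) * batch_count;

    for(index_t block_id = 0; block_id < grid_size; block_id++) // models the GPU grid
    {
        // Batch index; the divisor degenerates to 1 when each batch owns one block.
        const index_t g_idx = block_id / (Deterministic ? 1 : tiles_per_batch);

        if constexpr(Deterministic)
        {
            // The batch's only block visits every tile in a fixed order, so any
            // cross-tile accumulation happens with reproducible rounding.
            for(index_t i = 0; i < tiles_per_batch; i++)
                run_tile(g_idx, i);
        }
        else
        {
            run_tile(g_idx, block_id % tiles_per_batch); // execution order is up to the GPU
        }
    }
}

int main()
{
    launch<true>(/*batch_count=*/2, /*tiles_per_batch=*/3);
    launch<false>(2, 3);
    return 0;
}

Serializing the M-tiles inside one workgroup fixes the floating-point accumulation order within each batch; the cost is that those tiles no longer run in parallel.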
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp

@@ -75,6 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -145,7 +146,8 @@ using DeviceGemmInstance =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -215,7 +217,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -285,7 +288,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: DataType in, AccDataType out
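All four example files make the same two edits: define a compile-time Deterministic flag and append it after MaskingSpec in every device-op alias (the commented-out V2 instances are kept in sync as well). A minimal compilable model of that template-parameter change follows; DeviceOpModel and its parameter list are stand-ins, not CK types.

// Model of the API change: the device-op template gains a trailing
// `bool Deterministic` parameter, so each instantiation appends one argument.
enum class MaskingSpecialization { MaskDisabled, MaskOutUpperTriangle };

template <MaskingSpecialization MaskingSpec, bool Deterministic>
struct DeviceOpModel
{
    static constexpr bool deterministic = Deterministic;
};

static constexpr bool Deterministic = true; // as in the examples above

using DeviceGemmInstance =
    DeviceOpModel<MaskingSpecialization::MaskOutUpperTriangle, Deterministic>;

static_assert(DeviceGemmInstance::deterministic, "deterministic mode compiled in");

int main() { return 0; }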
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp

@@ -104,6 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 // DIM should be a multiple of 8.
 // If DIM <= 32 , ues prototype1 1st template.
@@ -178,7 +179,8 @@ using DeviceGemmInstanceFWD =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -247,7 +249,8 @@ using DeviceGemmInstanceBWD =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -317,7 +320,8 @@ using DeviceGemmInstanceFWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -386,7 +390,8 @@ using DeviceGemmInstanceBWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstanceBWD =
 //     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -455,7 +460,8 @@ using DeviceGemmInstanceBWD =
 //        2, // CShuffleNXdlPerWavePerShuffle
 //        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //        CShuffleBlockTransferScalarPerVector_NPerBlock,
-//        MaskingSpec>;
+//        MaskingSpec,
+//        Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -525,7 +531,8 @@ using DeviceGemmInstanceFWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -594,7 +601,8 @@ using DeviceGemmInstanceBWD =
         4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: S = alpha * Q * K^T
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp

@@ -75,6 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -145,7 +146,8 @@ using DeviceGemmInstance =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -215,7 +217,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -285,7 +288,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: DataType in, AccDataType out
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp

@@ -103,6 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 // DIM should be a multiple of 8.
 // If DIM <= 32 , ues prototype1 1st template.
@@ -177,7 +178,8 @@ using DeviceGemmInstanceFWD =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -246,7 +248,8 @@ using DeviceGemmInstanceBWD =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -316,7 +319,8 @@ using DeviceGemmInstanceFWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -385,7 +389,8 @@ using DeviceGemmInstanceBWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstanceBWD =
 //     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -454,7 +459,8 @@ using DeviceGemmInstanceBWD =
 //        2, // CShuffleNXdlPerWavePerShuffle
 //        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //        CShuffleBlockTransferScalarPerVector_NPerBlock,
-//        MaskingSpec>;
+//        MaskingSpec,
+//        Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -524,7 +530,8 @@ using DeviceGemmInstanceFWD =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 using DeviceGemmInstanceBWD =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -593,7 +600,8 @@ using DeviceGemmInstanceBWD =
         4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        MaskingSpec,    // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: S = alpha * Q * K^T
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -82,6 +82,7 @@ __global__ void
             const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
+            const index_t mblock,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const float p_drop,
@@ -115,9 +116,7 @@ __global__ void
     if constexpr(Deterministic)
     {
-        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        for(index_t i = 0; i < mblock; i++)
         {
-            if(get_block_1d_id() % num_blocks_per_batch == i)
-            {
-                GridwiseGemm::template Run<HasMainKBlockLoop>(
-                    p_a_grid + a_batch_offset,
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                p_a_grid + a_batch_offset,
@@ -147,8 +146,8 @@ __global__ void
                 block_2_ctile_map,
                 c0_matrix_mask,
                 p_drop,
-                ph);
-            }
+                ph,
+                i);
         }
     }
     else
@@ -180,7 +179,8 @@ __global__ void
             block_2_ctile_map,
             c0_matrix_mask,
             p_drop,
-            ph);
+            ph,
+            0);
     }
 #else
     ignore = p_a_grid;
@@ -707,7 +707,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     // Argument
     struct Argument : public BaseArgument
@@ -941,7 +942,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
+                (Deterministic ? 1
+                               : arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+                arg.batch_count_;

             float ave_time = 0;
@@ -971,7 +974,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     has_main_k_block_loop_,
                     Deterministic>;

                 return launch_and_time_kernel(stream_config,
                                               kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
@@ -1001,6 +1005,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     arg.ygrad_grid_desc_o0_m_o1_,
                     arg.block_2_ctile_map_,
                     arg.batch_count_,
+                    arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
                     arg.compute_base_ptr_of_batch_,
                     arg.c0_matrix_mask_,
                     arg.p_drop_,
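This is the kernel-side correction that gives the commit its name: the old deterministic path still launched one block per tile and merely guarded each loop iteration with get_block_1d_id() % num_blocks_per_batch == i, so tiles on different blocks could still complete in any order; the new path gives each batch a single block that executes every M-block index in sequence and forwards i into Run. A small model of the two shapes (illustrative names, not CK's) showing that the same tiles are covered:

// Model of the kernel-side fix. Old shape: one block per tile, each block
// scans all indices and runs only its own. New shape: one block per batch
// runs every tile index in order.
#include <cassert>
#include <vector>

using index_t = int;

// Tiles executed across the whole (modelled) grid, in execution order.
std::vector<index_t> old_shape(index_t num_blocks_per_batch)
{
    std::vector<index_t> executed;
    for(index_t block_id = 0; block_id < num_blocks_per_batch; block_id++)
        for(index_t i = 0; i < num_blocks_per_batch; i++)
            if(block_id % num_blocks_per_batch == i) // old guard
                executed.push_back(i);               // Run(...)
    return executed;
}

std::vector<index_t> new_shape(index_t mblock)
{
    std::vector<index_t> executed;
    for(index_t i = 0; i < mblock; i++) // the batch's single block
        executed.push_back(i);          // Run(..., i)
    return executed;
}

int main()
{
    // Same coverage; but on a GPU the old blocks ran concurrently, so only
    // the new single-block shape actually fixes the tile order.
    assert(old_shape(4) == new_shape(4));
    return 0;
}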
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -81,6 +81,7 @@ __global__ void
             const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
+            const index_t mblock,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const float p_drop,
@@ -114,9 +115,7 @@ __global__ void
     if constexpr(Deterministic)
     {
-        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        for(index_t i = 0; i < mblock; i++)
         {
-            if(get_block_1d_id() % num_blocks_per_batch == i)
-            {
-                GridwiseGemm::template Run<HasMainKBlockLoop>(
-                    p_a_grid + a_batch_offset,
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                p_a_grid + a_batch_offset,
@@ -146,8 +145,8 @@ __global__ void
                 block_2_ctile_map,
                 c0_matrix_mask,
                 p_drop,
-                ph);
-            }
+                ph,
+                i);
         }
     }
     else
@@ -179,7 +178,8 @@ __global__ void
             block_2_ctile_map,
             c0_matrix_mask,
             p_drop,
-            ph);
+            ph,
+            0);
     }
 #else
     ignore = p_a_grid;
@@ -706,7 +706,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     // Argument
     struct Argument : public BaseArgument
@@ -939,7 +940,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
+                (Deterministic ? 1
+                               : arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+                arg.batch_count_;

             // Gemm0_K
             const auto K =
@@ -973,7 +976,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     has_main_k_block_loop_,
                     Deterministic>;

                 return launch_and_time_kernel(stream_config,
                                               kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
@@ -1003,6 +1007,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     arg.ygrad_grid_desc_m0_o_m1_,
                     arg.block_2_ctile_map_,
                     arg.batch_count_,
+                    arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
                     arg.compute_base_ptr_of_batch_,
                     arg.c0_matrix_mask_,
                     arg.p_drop_,
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -45,7 +45,8 @@ template <typename GridwiseGemm,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring>
+          bool IsLseStoring,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -72,6 +73,7 @@ __global__ void
             const LSEGridDescriptor_M lse_grid_desc_m,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
+            const index_t mblock,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const ushort p_dropout_in_16bits,
@@ -101,6 +103,39 @@ __global__ void
     const index_t global_thread_id = get_thread_global_1d_id();
     ck::philox ph(seed, global_thread_id, offset);

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < mblock; i++)
+        {
+            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+                p_a_grid + a_batch_offset,
+                p_b_grid + b_batch_offset,
+                p_b1_grid + b1_batch_offset,
+                p_c_grid + c_batch_offset,
+                nullptr ? nullptr : p_z_grid + z_batch_offset,
+                nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+                p_shared,
+                a_element_op,
+                b_element_op,
+                acc_element_op,
+                b1_element_op,
+                c_element_op,
+                a_grid_desc_ak0_m_ak1,
+                b_grid_desc_bk0_n_bk1,
+                b1_grid_desc_bk0_n_bk1,
+                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                lse_grid_desc_m,
+                block_2_ctile_map,
+                c0_matrix_mask,
+                p_dropout_in_16bits,
+                p_dropout_rescale,
+                ph,
+                i);
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
@@ -124,7 +159,9 @@ __global__ void
         c0_matrix_mask,
         p_dropout_in_16bits,
         p_dropout_rescale,
-        ph);
+        ph,
+        0);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -216,6 +253,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
@@ -476,7 +514,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     // Argument
     // FIXME: constness
@@ -695,7 +734,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+                (Deterministic ? 1
+                               : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
+                arg.batch_count_;

             // Gemm0_K
             const auto K =
@@ -703,9 +744,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
             float ave_time = 0;

-            auto launch_kernel = [&](auto has_main_k_block_loop_,
-                                     auto is_dropout_,
-                                     auto is_lse_storing_) {
+            auto launch_kernel =
+                [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                 const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
@@ -729,9 +769,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
                     C0MatrixMask,
                     has_main_k_block_loop_,
                     is_dropout_,
-                    is_lse_storing_>;
+                    is_lse_storing_,
+                    Deterministic>;

                 return launch_and_time_kernel(stream_config,
                                               kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
@@ -755,6 +797,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
                     arg.lse_grid_desc_m_,
                     arg.block_2_ctile_map_,
                     arg.batch_count_,
+                    arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                     arg.compute_base_ptr_of_batch_,
                     arg.c0_matrix_mask_,
                     arg.p_dropout_in_16bits_,
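Every device file in this commit repeats the same launch arithmetic: grid_size = (Deterministic ? 1 : CalculateGridSize(desc)) * batch_count. A worked check under assumed tile and batch counts; calculate_grid_size below stands in for Block2CTileMap::CalculateGridSize.

// Worked check of the grid-size formula used across these device files.
#include <cassert>

using index_t = int;

index_t calculate_grid_size(index_t m_tiles) { return m_tiles; } // stand-in

template <bool Deterministic>
index_t grid_size(index_t m_tiles, index_t batch_count)
{
    return (Deterministic ? 1 : calculate_grid_size(m_tiles)) * batch_count;
}

int main()
{
    // Assumed shape: 8 M-tiles per batch, 16 batches.
    assert(grid_size<false>(8, 16) == 128); // one block per tile
    assert(grid_size<true>(8, 16) == 16);   // one block per batch
    return 0;
}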
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -79,7 +79,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -103,8 +103,6 @@ __global__ void
     if constexpr(Deterministic)
     {
         for(index_t i = 0; i < num_blocks_per_batch; i++)
         {
-            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
-            {
-                GridwiseGemm::template Run<HasMainKBlockLoop>(
-                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
@@ -134,8 +132,8 @@ __global__ void
                 arg_ptr[group_id].block_2_ctile_map_,
                 arg_ptr[group_id].c0_matrix_mask_,
                 p_dropout,
-                ph);
-            }
+                ph,
+                i);
         }
     }
     else
@@ -168,7 +166,8 @@ __global__ void
             arg_ptr[group_id].block_2_ctile_map_,
             arg_ptr[group_id].c0_matrix_mask_,
             p_dropout,
-            ph);
+            ph,
+            0);
     }
 #else
     ignore = group_kernel_args;
@@ -643,7 +642,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -825,7 +825,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
             const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
             const index_t grid_size_grp =
-                block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
+                (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
+                batch_count;
             const index_t BlockEnd = grid_size_ + grid_size_grp;

             // batch stride
             const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
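In the grouped kernels the batch index g_idx is recovered from the flat block id, so its divisor has to track the launch shape: one block per tile divides by num_blocks_per_batch, one block per batch divides by 1. A small check with assumed numbers:

// Model of the g_idx fix for the grouped kernels.
#include <cassert>

using index_t = int;

index_t g_idx(index_t block_id, index_t block_start, index_t num_blocks_per_batch, bool deterministic)
{
    return (block_id - block_start) / (deterministic ? 1 : num_blocks_per_batch);
}

int main()
{
    // Non-deterministic: 3 tiles per batch; block 7 of a group starting at
    // block 1 belongs to batch (7 - 1) / 3 == 2.
    assert(g_idx(7, 1, 3, false) == 2);
    // Deterministic: one block per batch; block 3 of the same group is batch 2.
    assert(g_idx(3, 1, 3, true) == 2);
    return 0;
}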
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -79,7 +79,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -103,8 +103,6 @@ __global__ void
     if constexpr(Deterministic)
     {
         for(index_t i = 0; i < num_blocks_per_batch; i++)
         {
-            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
-            {
-                GridwiseGemm::template Run<HasMainKBlockLoop>(
-                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
@@ -134,8 +132,8 @@ __global__ void
                 arg_ptr[group_id].block_2_ctile_map_,
                 arg_ptr[group_id].c0_matrix_mask_,
                 p_dropout,
-                ph);
-            }
+                ph,
+                i);
         }
     }
     else
@@ -168,7 +166,8 @@ __global__ void
             arg_ptr[group_id].block_2_ctile_map_,
             arg_ptr[group_id].c0_matrix_mask_,
             p_dropout,
-            ph);
+            ph,
+            0);
     }
 #else
     ignore = group_kernel_args;
@@ -636,7 +635,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -818,7 +818,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
             const index_t grid_size_grp =
-                block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
+                (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
+                batch_count;
             const index_t BlockEnd = grid_size_ + grid_size_grp;

             // batch stride
             const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp

@@ -33,7 +33,8 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring>
+          bool IsLseStoring,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -83,7 +84,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -98,6 +99,44 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+                arg_ptr[group_id].p_z_grid_ == nullptr
+                    ? nullptr
+                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+                arg_ptr[group_id].p_lse_grid_ == nullptr
+                    ? nullptr
+                    : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                p_shared,
+                a_element_op,
+                b_element_op,
+                acc_element_op,
+                b1_element_op,
+                c_element_op,
+                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+                arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                arg_ptr[group_id].lse_grid_desc_m_,
+                arg_ptr[group_id].block_2_ctile_map_,
+                arg_ptr[group_id].c0_matrix_mask_,
+                p_dropout_in_16bits,
+                p_dropout_rescale,
+                ph,
+                i);
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -105,7 +144,8 @@ __global__ void
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
         arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-        arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
-                                                 : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        arg_ptr[group_id].p_lse_grid_ == nullptr
+            ? nullptr
+            : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
@@ -124,7 +164,9 @@ __global__ void
         arg_ptr[group_id].c0_matrix_mask_,
         p_dropout_in_16bits,
         p_dropout_rescale,
-        ph);
+        ph,
+        0);
+    }
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -206,6 +248,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
@@ -487,7 +530,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        Deterministic>;

     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -638,7 +682,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
             const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
             const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
             const index_t grid_size_grp =
-                block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
+                (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
+                batch_count;
             const index_t BlockEnd = grid_size_ + grid_size_grp;

             // batch stride
@@ -778,7 +823,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                     CElementwiseOperation,
                     has_main_k_block_loop_,
                     is_dropout_,
-                    is_lse_storing_>;
+                    is_lse_storing_,
+                    Deterministic>;

                 return launch_and_time_kernel(
                     stream_config,
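Throughout these files Deterministic travels as a non-type template parameter rather than a runtime argument, which lets if constexpr compile the serial loop in or out so the non-deterministic build keeps its original straight-line kernel body. A minimal illustration (not CK code):

// Why the flag is a template parameter: `if constexpr` discards the untaken
// branch at compile time. Illustrative names only.
#include <cstdio>

using index_t = int;

template <bool Deterministic>
void kernel_body(index_t mblock)
{
    if constexpr(Deterministic)
    {
        for(index_t i = 0; i < mblock; i++) // serial tile walk, compiled in
            std::printf("tile %d\n", i);    // stands in for GridwiseGemm::Run(..., i)
    }
    else
    {
        std::printf("tile (parallel)\n");   // single Run(..., 0); the loop is compiled out
    }
}

int main()
{
    kernel_body<true>(3);
    kernel_body<false>(3);
    return 0;
}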
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
99ebfeba
...
@@ -86,6 +86,7 @@ template <typename InputDataType,
...
@@ -86,6 +86,7 @@ template <typename InputDataType,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
{
...
@@ -1265,7 +1266,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1265,7 +1266,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
ck
::
philox
&
ph
)
ck
::
philox
&
ph
,
const
index_t
block_idx_m
)
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
...
@@ -1305,9 +1307,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1305,9 +1307,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
return
;
return
;
}
}
const
index_t
block_work_idx_m
=
Deterministic
?
block_idx_m
:
block_work_idx
[
I0
];
// HACK: this force m/o_block_data_idx_on_grid into SGPR
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
_m
*
MPerBlock
);
// const index_t o_block_data_idx_on_grid =
// const index_t o_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
...
@@ -1512,7 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1512,7 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
1
,
1
,
false
>
{
false
>
{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
_m
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
acc0_thread_origin
[
I4
])};
// mperxdl
...
@@ -1574,7 +1578,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1574,7 +1578,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
]
,
// MBlockId
make_multi_index
(
block_work_idx
_m
,
// MBlockId
0
,
// NBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
...
@@ -1720,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
         const auto y_thread_data_on_grid_idx =
             make_multi_index(
-                block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
+                block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
             y_thread_data_on_block_idx;

         // performs for y
...
@@ -2320,7 +2324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(0, 0, 0, 0),
              qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
-             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
+             make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
              c_element_op};

         // space filling curve for threadwise C in VGPR
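Taken together, the V1 changes resolve the m-block index once and reuse it for every M-indexed origin: the LSE copy, the Z (dropout-mask) copy, the Y read, and the dQ write above. That uniformity appears to be the substance of "correct deterministic mode": no consumer of the m coordinate is left reading the tile map directly. In sketch form (restating the commit's pattern, with the descriptor arguments abbreviated):

// Resolve once...
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// ...then derive every M-indexed origin from the same value, so all
// outputs move together when the caller pins the m-block.
const auto lse_origin = make_multi_index(block_work_idx_m, 0, 0, 0);
const auto dq_origin  = make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0);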
...

include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
View file @ 99ebfeba
...
@@ -86,6 +86,7 @@ template <typename InputDataType,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
+          bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 {
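Note that `Deterministic` arrives as a non-type template parameter rather than a runtime flag, so the index selection is resolved per instantiation and the non-deterministic build keeps its original code with no added branch. A standalone sketch of that property (hypothetical function, not in the commit):

// The ternary folds to one side at compile time for each instantiation;
// the Deterministic == false specialization emits no runtime test.
template <bool Deterministic>
__device__ index_t pick_m_block(index_t forced_idx, index_t mapped_idx)
{
    return Deterministic ? forced_idx : mapped_idx;
}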
...
@@ -1175,7 +1176,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
         const float p_drop,
-        ck::philox& ph)
+        ck::philox& ph,
+        const index_t block_idx_m)
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
...
@@ -1215,9 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             return;
         }

+        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+
         // HACK: this force m/o_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
         const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
...
@@ -1444,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             1,
             false>{lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-                   make_multi_index(block_work_idx[I0],     // mblock
+                   make_multi_index(block_work_idx_m,       // mblock
                                     acc0_thread_origin[I0], // mrepeat
                                     acc0_thread_origin[I2], // mwave
                                     acc0_thread_origin[I4])}; // mperxdl
...
@@ -1506,7 +1510,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                  make_multi_index(block_work_idx[I0], // MBlockId
+                  make_multi_index(block_work_idx_m,   // MBlockId
                                    0, // NBlockId
                                    0, // mrepeat
                                    0, // nrepeat
...
@@ -1643,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
         const auto y_thread_data_on_grid_idx =
             make_multi_index(
-                block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
+                block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
             y_thread_data_on_block_idx;

         // performs double duty for both y and ygrad
...
@@ -2270,7 +2274,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(0, 0, 0, 0),
              qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
-             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
+             make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
              c_element_op};

         // space filling curve for threadwise C in VGPR
...

include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @ 99ebfeba
...
@@ -85,6 +85,7 @@ template <typename FloatAB,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
+          bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 {
...
@@ -445,7 +446,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const C0MatrixMask& c0_matrix_mask,
         const ushort p_dropout_in_16bits,
         FloatGemmAcc p_dropout_rescale,
-        ck::philox ph)
+        ck::philox& ph,
+        const index_t block_idx_m)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
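Besides gaining `block_idx_m`, the forward `Run` now takes `ck::philox& ph` by reference where it previously took the RNG by value. Philox is a counter-based generator, so a by-value copy would advance only the callee's private state; the reference presumably lets whatever offset the kernel consumes remain visible to the caller. A toy stand-in (not ck's philox) showing the difference:

// Minimal stateful-RNG stand-in; real philox keys/counters are elided.
struct ToyRng
{
    unsigned counter = 0;
    unsigned next() { return counter++; } // each draw advances the state
};

void draw_by_value(ToyRng rng) { rng.next(); } // mutates a copy; caller still sees counter == 0
void draw_by_ref(ToyRng& rng) { rng.next(); }  // caller's counter advances to 1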
...
@@ -470,9 +472,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             return;
         }

+        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+
         // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
         const index_t gemm1_n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
...
@@ -835,7 +839,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             InMemoryDataOperationEnum::Set,
             1,
             false>{lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-                   make_multi_index(block_work_idx[I0],     // mblock
+                   make_multi_index(block_work_idx_m,       // mblock
                                     0,                      // mrepeat
                                     acc0_thread_origin[I2], // mwave
                                     acc0_thread_origin[I4]), // mperxdl
@@ -897,7 +901,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -897,7 +901,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
]
,
// MBlockId
make_multi_index
(
block_work_idx
_m
,
// MBlockId
0
,
// NBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
...
@@ -1319,7 +1323,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(0, 0, 0, 0),
              c_grid_desc_mblock_mperblock_nblock_nperblock,
-             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
+             make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
              c_element_op};

         // space filling curve for threadwise C in VGPR
...