Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
957d5dee
"...composable_kernel.git" did not exist on "e0d8806ca1cd8611d387f1fb441d3c9e92174ec5"
Commit
957d5dee
authored
Oct 11, 2023
by
danyao12
Browse files
resolve conflict
parents
10836d41
3f4eae1d
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
226 additions
and
356 deletions
+226
-356
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v2.cpp
+7
-11
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+15
-15
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
+7
-11
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+15
-15
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
...tten_bias/batched_multihead_attention_bias_forward_v2.cpp
+7
-11
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp
...as/batched_multihead_attention_bias_forward_v2_zcheck.cpp
+7
-11
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
...tten_bias/grouped_multihead_attention_bias_forward_v2.cpp
+7
-11
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+97
-145
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+1
-14
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+38
-89
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+2
-9
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
View file @
957d5dee
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
// Ref Gemm0: DataType in, AccDataType out
// Ref Gemm0: DataType in, AccDataType out
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
957d5dee
This diff is collapsed.
Click to expand it.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
View file @
957d5dee
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
// Ref Gemm0: DataType in, AccDataType out
// Ref Gemm0: DataType in, AccDataType out
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
957d5dee
This diff is collapsed.
Click to expand it.
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
View file @
957d5dee
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
// Ref Gemm0: DataType in, AccDataType out
// Ref Gemm0: DataType in, AccDataType out
...
...
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp
View file @
957d5dee
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
4
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
using
DeviceDropoutInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedDropout
<
NumDimG
,
using
DeviceDropoutInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedDropout
<
NumDimG
,
...
...
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
View file @
957d5dee
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
1
,
MaskingSpec
,
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
// Ref Gemm0: DataType in, AccDataType out
// Ref Gemm0: DataType in, AccDataType out
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
957d5dee
...
@@ -1488,10 +1488,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1488,10 +1488,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
void
*
p_qgrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
void
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
,
const
void
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
void
*
p_acc1_bias
,
D0DataType
*
p_d0grad_grid
,
void
*
p_d0grad_grid
,
D1DataType
*
p_d1grad_grid
,
void
*
p_d1grad_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
957d5dee
...
@@ -1344,10 +1344,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1344,10 +1344,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void
*
p_qgrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
void
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
,
const
void
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
void
*
p_acc1_bias
,
D0DataType
*
p_d0grad_grid
,
void
*
p_d0grad_grid
,
D1DataType
*
p_d1grad_grid
,
void
*
p_d1grad_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1390,8 +1390,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1390,8 +1390,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
),
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
),
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_d0grad_grid
),
static_cast
<
D0DataType
*>
(
p_d0grad_grid
),
static_cast
<
const
D1DataType
*>
(
p_d1grad_grid
),
static_cast
<
D1DataType
*>
(
p_d1grad_grid
),
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
957d5dee
...
@@ -47,8 +47,7 @@ template <typename GridwiseGemm,
...
@@ -47,8 +47,7 @@ template <typename GridwiseGemm,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
IsDropout
,
bool
IsLseStoring
,
bool
IsLseStoring
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -79,7 +78,6 @@ __global__ void
...
@@ -79,7 +78,6 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
index_t
h_ratio
,
const
index_t
mblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
uint8_t
p_dropout_in_uint8_t
,
const
uint8_t
p_dropout_in_uint8_t
,
...
@@ -124,73 +122,34 @@ __global__ void
...
@@ -124,73 +122,34 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
if
constexpr
(
Deterministic
)
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
{
p_a_grid
+
a_batch_offset
,
for
(
index_t
i
=
0
;
i
<
mblock
;
i
++
)
p_b_grid
+
b_batch_offset
,
{
tmp_p_d0_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
p_b1_grid
+
b1_batch_offset
,
p_a_grid
+
a_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
tmp_p_d0_grid
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_shared
,
p_c_grid
+
c_batch_offset
,
a_element_op
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
b_element_op
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
acc_element_op
,
p_shared
,
b1_element_op
,
a_element_op
,
c_element_op
,
b_element_op
,
a_grid_desc_ak0_m_ak1
,
acc_element_op
,
b_grid_desc_bk0_n_bk1
,
b1_element_op
,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
c_element_op
,
b1_grid_desc_bk0_n_bk1
,
a_grid_desc_ak0_m_ak1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
b_grid_desc_bk0_n_bk1
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
lse_grid_desc_m
,
b1_grid_desc_bk0_n_bk1
,
block_2_ctile_map
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_matrix_mask
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
p_dropout_in_uint8_t
,
lse_grid_desc_m
,
p_dropout_rescale
,
block_2_ctile_map
,
ph
,
c0_matrix_mask
,
z_random_matrix_offset
,
p_dropout_in_uint8_t
,
raw_n_padded
);
p_dropout_rescale
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
i
);
}
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
tmp_p_d0_grid
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
p_dropout_in_uint8_t
,
p_dropout_rescale
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
0
);
}
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
...
@@ -214,7 +173,6 @@ __global__ void
...
@@ -214,7 +173,6 @@ __global__ void
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
h_ratio
;
ignore
=
mblock
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
c0_matrix_mask
;
ignore
=
p_dropout_in_uint8_t
;
ignore
=
p_dropout_in_uint8_t
;
...
@@ -299,7 +257,6 @@ template <index_t NumDimG,
...
@@ -299,7 +257,6 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
struct
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
:
public
DeviceBatchedMultiheadAttentionForward
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionForward
<
NumDimG
,
...
@@ -579,8 +536,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -579,8 +536,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Acc1BiasTransferSrcScalarPerVector
,
Acc1BiasTransferSrcScalarPerVector
,
LoopSched
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
>
;
Deterministic
>
;
// Argument
// Argument
// FIXME: constness
// FIXME: constness
...
@@ -836,9 +792,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -836,9 +792,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
(
Deterministic
?
1
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
:
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
))
*
arg
.
batch_count_
;
// Gemm0_K
// Gemm0_K
const
auto
K
=
const
auto
K
=
...
@@ -846,74 +800,72 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -846,74 +800,72 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
auto
is_dropout_
,
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2
<
auto
is_lse_storing_
)
{
GridwiseGemm
,
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2
<
ADataType
,
// TODO: distiguish A/B datatype
GridwiseGemm
,
D0DataType
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
D0DataType
,
ZDataType
,
CDataType
,
LSEDataType
,
ZDataType
,
GemmAccDataType
,
LSEDataType
,
AElementwiseOperation
,
GemmAccDataType
,
BElementwiseOperation
,
AElementwiseOperation
,
AccElementwiseOperation
,
BElementwiseOperation
,
B1ElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
B1ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
CElementwiseOperation
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
DeviceOp
::
LSEGridDesc_M
,
ComputeBasePtrOfStridedBatch
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
C0MatrixMask
,
ComputeBasePtrOfStridedBatch
,
has_main_k_block_loop_
,
C0MatrixMask
,
is_dropout_
,
has_main_k_block_loop_
,
is_lse_storing_
,
is_dropout_
,
Deterministic
>
;
is_lse_storing_
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_d0_grid_
,
arg
.
p_d0_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_z_grid_
,
arg
.
p_z_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_lse_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
h_ratio_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
),
arg
.
c0_matrix_mask_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
p_dropout_in_uint8_t_
,
arg
.
c0_matrix_mask_
,
arg
.
p_dropout_rescale_
,
arg
.
p_dropout_in_uint8_t_
,
arg
.
seed_
,
arg
.
p_dropout_rescale_
,
arg
.
offset_
,
arg
.
seed_
,
arg
.
m_raw_padded_
,
arg
.
offset_
,
arg
.
n_raw_padded_
);
arg
.
m_raw_padded_
,
};
arg
.
n_raw_padded_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
// to concern Gemm0's loop
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
957d5dee
...
@@ -1079,7 +1079,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1079,7 +1079,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
957d5dee
...
@@ -1148,7 +1148,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1148,7 +1148,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
957d5dee
...
@@ -969,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -969,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
957d5dee
...
@@ -467,19 +467,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -467,19 +467,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
...
@@ -1039,7 +1026,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1039,7 +1026,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
View file @
957d5dee
...
@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
...
@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_g
rid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
lse_g
s_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
957d5dee
...
@@ -35,8 +35,7 @@ template <typename GridwiseGemm,
...
@@ -35,8 +35,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
IsDropout
,
bool
IsLseStoring
,
bool
IsLseStoring
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -88,7 +87,7 @@ __global__ void
...
@@ -88,7 +87,7 @@ __global__ void
// per-group batch offset
// per-group batch offset
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
)
)
;
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
...
@@ -115,84 +114,38 @@ __global__ void
...
@@ -115,84 +114,38 @@ __global__ void
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
}
}
if
constexpr
(
Deterministic
)
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
{
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
{
tmp_p_d0_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
tmp_p_d0_grid
,
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
?
nullptr
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
:
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
?
nullptr
p_shared
,
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
a_element_op
,
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
b_element_op
,
?
nullptr
acc_element_op
,
:
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
b1_element_op
,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
c_element_op
,
p_shared
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
a_element_op
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
b_element_op
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
acc_element_op
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
b1_element_op
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
c_element_op
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
p_dropout_in_uint8_t
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
p_dropout_rescale
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
ph
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
arg_ptr
[
group_id
].
block_2_ctile_map_
,
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
raw_n_padded_
);
p_dropout_in_uint8_t
,
p_dropout_rescale
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
i
);
}
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
tmp_p_d0_grid
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout_in_uint8_t
,
p_dropout_rescale
,
ph
,
arg_ptr
[
group_id
].
z_random_matrix_offset_
+
g_idx
*
arg_ptr
[
group_id
].
raw_m_padded_
*
arg_ptr
[
group_id
].
raw_n_padded_
,
arg_ptr
[
group_id
].
raw_n_padded_
,
0
);
}
#else
#else
ignore
=
group_kernel_args
;
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
group_count
;
...
@@ -282,7 +235,6 @@ template <index_t NumDimG,
...
@@ -282,7 +235,6 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
struct
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
:
public
DeviceGroupedMultiheadAttentionForward
<
NumDimG
,
:
public
DeviceGroupedMultiheadAttentionForward
<
NumDimG
,
...
@@ -598,8 +550,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -598,8 +550,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
Acc1BiasTransferSrcScalarPerVector
,
Acc1BiasTransferSrcScalarPerVector
,
LoopSched
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
>
;
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
@@ -789,8 +740,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -789,8 +740,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const
auto
block_2_ctile_map
=
Block2CTileMap
(
c_grid_desc_m_n
,
BlockStart
);
const
auto
block_2_ctile_map
=
Block2CTileMap
(
c_grid_desc_m_n
,
BlockStart
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
grid_size_grp
=
const
index_t
grid_size_grp
=
(
Deterministic
?
1
:
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
))
*
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
*
batch_count
;
batch_count
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
// batch stride
// batch stride
...
@@ -801,7 +751,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -801,7 +751,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
// C0 mask
const
auto
c0_matrix_mask
=
const
auto
c0_matrix_mask
=
...
@@ -967,8 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -967,8 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
,
has_main_k_block_loop_
,
use_dropout_
,
use_dropout_
,
is_lse_storing_
,
is_lse_storing_
>
;
Deterministic
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
957d5dee
...
@@ -320,18 +320,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -320,18 +320,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
O
!=
K
)
if
(
O
!=
K
)
{
{
std
::
cerr
<<
"O = "
<<
O
<<
" K = "
<<
K
<<
std
::
endl
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
return
false
;
}
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
{
std
::
cerr
<<
"M = "
<<
M
<<
" O = "
<<
O
<<
" y_grid_desc_m_o = "
<<
y_grid_desc_m_o
.
GetLength
(
I0
)
<<
" , "
<<
y_grid_desc_m_o
.
GetLength
(
I1
)
<<
std
::
endl
;
std
::
cerr
<<
"Un-matched sizes!"
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
O
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
{
std
::
cerr
<<
"M = "
<<
M
<<
" N = "
<<
N
<<
" O = "
<<
O
<<
std
::
endl
;
std
::
cerr
<<
"MPerBlock = "
<<
MPerBlock
<<
" NPerBlock = "
<<
NPerBlock
<<
" KPerBlock = "
<<
KPerBlock
<<
std
::
endl
;
std
::
cerr
<<
"Un-aligned sizes!"
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
957d5dee
...
@@ -94,7 +94,6 @@ template <typename FloatAB,
...
@@ -94,7 +94,6 @@ template <typename FloatAB,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
{
...
@@ -531,8 +530,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -531,8 +530,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
FloatGemmAcc
p_dropout_rescale
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
const
index_t
z_random_matrix_offset
,
const
index_t
z_random_matrix_offset
,
const
index_t
raw_n_padded
,
const
index_t
raw_n_padded
)
const
index_t
block_idx_m
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
@@ -557,7 +555,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -557,7 +555,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return
;
return
;
}
}
const
index_t
block_work_idx_m
=
Deterministic
?
block_idx_m
:
block_work_idx
[
I0
];
const
index_t
block_work_idx_m
=
block_work_idx
[
I0
];
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
...
@@ -1145,11 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1145,11 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0
),
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
do
do
{
{
auto
n_block_data_idx_on_grid
=
auto
n_block_data_idx_on_grid
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment