gaoqiong / composable_kernel

Commit 84a81ae2
authored Apr 17, 2023 by danyao12
parent e576c081

add deterministic mode for FA unittest
Showing 6 changed files with 324 additions and 138 deletions (+324 -138):
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp   +13 -8
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp   +13 -8
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp   +74 -30
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp   +74 -30
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp   +75 -31
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp   +75 -31
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp

@@ -95,6 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 // DIM should be a multiple of 8.
 // If DIM <= 32 , ues prototype1 1st template.
@@ -168,7 +169,8 @@ using DeviceGemmInstance =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -237,7 +239,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstance =
 //     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -306,7 +309,8 @@ using DeviceGemmInstance =
 //         2, // CShuffleNXdlPerWavePerShuffle
 //         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -375,7 +379,8 @@ using DeviceGemmInstance =
         4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: S = alpha * Q * K^T
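The example opts in by defining a file-scope static constexpr bool Deterministic = true and forwarding it as the new trailing template argument of each DeviceGemmInstance; the grouped example below does the same. Because the device ops declare bool Deterministic ahead of the defaulted LoopSched parameter (see the header diffs further down), every instantiation must now pass the flag explicitly. A minimal, self-contained sketch of this wiring pattern (DeviceOp and the printed strings are illustrative stand-ins, not CK code):

#include <cstdio>

// A compile-time bool rides down as a non-type template parameter and is
// branched on with `if constexpr`, so the unused path is compiled out.
template <bool Deterministic>
struct DeviceOp
{
    static void run()
    {
        if constexpr(Deterministic)
        {
            std::printf("deterministic path: fixed work order\n");
        }
        else
        {
            std::printf("default path: order left to the scheduler\n");
        }
    }
};

static constexpr bool Deterministic = true; // mirrors the example's file-scope flag

int main() { DeviceOp<Deterministic>::run(); }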
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp

@@ -94,6 +94,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;

 // DIM should be a multiple of 8.
 // If DIM <= 32 , ues prototype1 1st template.
@@ -167,7 +168,8 @@ using DeviceGemmInstance =
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -236,7 +238,8 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstance =
 //     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -305,7 +308,8 @@ using DeviceGemmInstance =
 //         2, // CShuffleNXdlPerWavePerShuffle
 //         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -374,7 +378,8 @@ using DeviceGemmInstance =
         4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #endif

 // Ref Gemm0: S = alpha * Q * K^T
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -48,7 +48,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -112,6 +113,46 @@ __global__ void
     ck::philox ph(seed, global_thread_id, offset);

     ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(get_block_1d_id() % num_blocks_per_batch == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    p_a_grid + a_batch_offset,
+                    p_b_grid + b_batch_offset,
+                    z_matrix_ptr,
+                    p_b1_grid + b1_batch_offset,
+                    p_c_grid + c_batch_offset,
+                    p_lse_grid + lse_batch_offset,
+                    p_ygrad_grid + c_batch_offset,
+                    p_qgrad_grid + a_batch_offset,
+                    p_kgrad_grid + b_batch_offset,
+                    p_vgrad_grid + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    a_grid_desc_ak0_m_ak1,
+                    b_grid_desc_bk0_n_bk1,
+                    c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    b1_grid_desc_bk0_n_bk1,
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    lse_grid_desc_m,
+                    vgrad_grid_desc_n_o,
+                    ygrad_grid_desc_o0_m_o1,
+                    block_2_ctile_map,
+                    c0_matrix_mask,
+                    p_drop,
+                    ph);
+            }
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   z_matrix_ptr,
 ...
@@ -140,6 +181,7 @@ __global__ void
                                                   c0_matrix_mask,
                                                   p_drop,
                                                   ph);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
 ...
@@ -233,6 +275,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
@@ -925,7 +968,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;

             return launch_and_time_kernel(stream_config,
                                           kernel,
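The deterministic path above gates the gridwise GEMM behind a loop over the blocks that share one batch: iteration i does work only in the block whose intra-batch index is i. The bitwise reproducibility itself comes from fixing the order in which gradient contributions are accumulated inside GridwiseGemm::Run, which this diff does not show. A host-side rendering of just the gating pattern, with run_tile, block_1d_id, and num_blocks_per_batch as illustrative stand-ins:

#include <cstdio>

template <bool Deterministic>
void dispatch(int block_1d_id, int num_blocks_per_batch)
{
    auto run_tile = [&] { std::printf("block %d runs its tile\n", block_1d_id); };
    if constexpr(Deterministic)
    {
        // Each block scans the whole intra-batch index range but acts only
        // on the iteration matching its own position, mirroring the
        // `if(get_block_1d_id() % num_blocks_per_batch == i)` gate above.
        for(int i = 0; i < num_blocks_per_batch; i++)
        {
            if(block_1d_id % num_blocks_per_batch == i)
            {
                run_tile();
            }
        }
    }
    else
    {
        run_tile(); // ungated: scheduling, and hence accumulation order, may vary
    }
}

int main()
{
    for(int b = 0; b < 4; b++)
        dispatch</*Deterministic=*/true>(b, 4);
}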
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -111,6 +112,46 @@ __global__ void
     ck::philox ph(seed, global_thread_id, offset);

     ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(get_block_1d_id() % num_blocks_per_batch == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    p_a_grid + a_batch_offset,
+                    p_b_grid + b_batch_offset,
+                    z_matrix_ptr,
+                    p_b1_grid + b1_batch_offset,
+                    p_c_grid + c_batch_offset,
+                    p_lse_grid + lse_batch_offset,
+                    p_ygrad_grid + c_batch_offset,
+                    p_qgrad_grid + a_batch_offset,
+                    p_kgrad_grid + b_batch_offset,
+                    p_vgrad_grid + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    a_grid_desc_ak0_m_ak1,
+                    b_grid_desc_bk0_n_bk1,
+                    c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    b1_grid_desc_bk0_n_bk1,
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    lse_grid_desc_m,
+                    vgrad_grid_desc_n_o,
+                    ygrad_grid_desc_m0_o_m1,
+                    block_2_ctile_map,
+                    c0_matrix_mask,
+                    p_drop,
+                    ph);
+            }
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   z_matrix_ptr,
 ...
@@ -139,6 +180,7 @@ __global__ void
                                                   c0_matrix_mask,
                                                   p_drop,
                                                   ph);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
 ...
@@ -232,6 +274,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
@@ -927,7 +970,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;

             return launch_and_time_kernel(stream_config,
                                           kernel,
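At each launch site the device op appends Deterministic to the kernel's template argument list next to has_main_k_block_loop_. Template arguments of a __global__ kernel must be compile-time constants; has_main_k_block_loop_ is decided at run time from the problem size, so a launcher has to branch over it to pick an instantiation, while Deterministic is already fixed by the device op's own template parameter. A generic sketch of mapping runtime flags onto instantiations (not CK's launcher; kernel_stub is a placeholder):

template <bool HasMainKBlockLoop, bool Deterministic>
void kernel_stub() {} // placeholder for the templated __global__ kernel

using KernelFn = void (*)();

// Map two runtime flags onto the four possible instantiations.
KernelFn select_kernel(bool has_main_k_block_loop, bool deterministic)
{
    if(has_main_k_block_loop)
        return deterministic ? kernel_stub<true, true> : kernel_stub<true, false>;
    return deterministic ? kernel_stub<false, true> : kernel_stub<false, false>;
}

int main()
{
    select_kernel(true, true)(); // would be launched instead of called directly
}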
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,6 +100,46 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    z_matrix_ptr,
+                    arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+                    arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                    arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+                    arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+                    arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                    arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+                    arg_ptr[group_id].lse_grid_desc_m_,
+                    arg_ptr[group_id].vgrad_grid_desc_n_o_,
+                    arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
+                    arg_ptr[group_id].block_2_ctile_map_,
+                    arg_ptr[group_id].c0_matrix_mask_,
+                    p_dropout,
+                    ph);
+            }
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
 ...
@@ -128,6 +169,7 @@ __global__ void
         arg_ptr[group_id].c0_matrix_mask_,
         p_dropout,
         ph);
+    }
 #else
     ignore = group_kernel_args;
     ignore = group_count;
 ...
@@ -211,6 +253,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
@@ -920,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;

             return launch_and_time_kernel(
                 stream_config,
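The only structural difference from the batched kernels is how a block finds its slot: the grouped kernels read their arguments from arg_ptr[group_id] and take the intra-batch position relative to the group's first block, block_start_. A hedged sketch of that index calculation (GroupArg and its field are illustrative stand-ins for CK's per-group kernel argument struct):

#include <cassert>

struct GroupArg
{
    int block_start_; // first grid block owned by this group
};

// Mirrors the gate `((block_id - arg_ptr[group_id].block_start_)
// % num_blocks_per_batch) == i` in the grouped kernels above.
int intra_batch_index(const GroupArg& arg, int block_id, int num_blocks_per_batch)
{
    assert(block_id >= arg.block_start_);
    return (block_id - arg.block_start_) % num_blocks_per_batch;
}

int main()
{
    GroupArg g{128};
    return intra_batch_index(g, 130, 4) == 2 ? 0 : 1;
}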
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,6 +100,46 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    z_matrix_ptr,
+                    arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+                    arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                    arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+                    arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+                    arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                    arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+                    arg_ptr[group_id].lse_grid_desc_m_,
+                    arg_ptr[group_id].vgrad_grid_desc_n_o_,
+                    arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
+                    arg_ptr[group_id].block_2_ctile_map_,
+                    arg_ptr[group_id].c0_matrix_mask_,
+                    p_dropout,
+                    ph);
+            }
+        }
+    }
+    else
+    {
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
 ...
@@ -128,6 +169,7 @@ __global__ void
         arg_ptr[group_id].c0_matrix_mask_,
         p_dropout,
         ph);
+    }
 #else
     ignore = group_kernel_args;
     ignore = group_count;
 ...
@@ -211,6 +253,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
@@ -912,7 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;

             return launch_and_time_kernel(
                 stream_config,
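With Deterministic = true, the FA backward unit tests this commit targets can demand bitwise-identical gradients across repeated runs instead of an epsilon comparison. A sketch of that style of check (illustrative only; run_attention_backward stands in for invoking the device ops above and is given a dummy body so the snippet is self-contained):

#include <cassert>
#include <cstring>
#include <vector>

// Hypothetical stand-in for launching the attention backward pass and
// copying dQ/dK/dV back to the host as one flat buffer.
std::vector<float> run_attention_backward()
{
    return {1.0f, 2.0f, 3.0f};
}

int main()
{
    auto g0 = run_attention_backward();
    auto g1 = run_attention_backward();
    assert(g0.size() == g1.size());
    // memcmp rather than an epsilon: a deterministic mode promises the
    // same bits every run, not just values that are close.
    assert(std::memcmp(g0.data(), g1.data(), g0.size() * sizeof(float)) == 0);
    return 0;
}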