Commit 84a81ae2 authored by danyao12

add deterministic mode for FA unittest

parent e576c081
@@ -91,10 +91,11 @@ static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 #endif
 
 static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
 
 // DIM should be a multiple of 8.
 // If DIM <= 32 , use prototype1 1st template.
@@ -168,7 +169,8 @@ using DeviceGemmInstance =
         1,               // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -237,7 +239,8 @@ using DeviceGemmInstance =
         2,               // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstance =
 //     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -306,7 +309,8 @@ using DeviceGemmInstance =
 //         2,               // CShuffleNXdlPerWavePerShuffle
 //         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -375,7 +379,8 @@ using DeviceGemmInstance =
         4,               // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #endif
 
 // Ref Gemm0: S = alpha * Q * K^T
...
@@ -90,10 +90,11 @@ static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 #endif
 
 static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
 
 // DIM should be a multiple of 8.
 // If DIM <= 32 , use prototype1 1st template.
@@ -167,7 +168,8 @@ using DeviceGemmInstance =
         1,               // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -236,7 +238,8 @@ using DeviceGemmInstance =
         2,               // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 // using DeviceGemmInstance =
 //     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -305,7 +308,8 @@ using DeviceGemmInstance =
 //         2,               // CShuffleNXdlPerWavePerShuffle
 //         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 //         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -374,7 +378,8 @@ using DeviceGemmInstance =
         4,               // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec,  // MaskingSpecialization
+        Deterministic>;
 #endif
 
 // Ref Gemm0: S = alpha * Q * K^T
...
@@ -48,7 +48,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -112,34 +113,75 @@ __global__ void
     ck::philox ph(seed, global_thread_id, offset);
     ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
-                                                  z_matrix_ptr,
-                                                  p_b1_grid + b1_batch_offset,
-                                                  p_c_grid + c_batch_offset,
-                                                  p_lse_grid + lse_batch_offset,
-                                                  p_ygrad_grid + c_batch_offset,
-                                                  p_qgrad_grid + a_batch_offset,
-                                                  p_kgrad_grid + b_batch_offset,
-                                                  p_vgrad_grid + b1_batch_offset,
-                                                  p_shared,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  acc_element_op,
-                                                  b1_element_op,
-                                                  c_element_op,
-                                                  a_grid_desc_ak0_m_ak1,
-                                                  b_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                  b1_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  lse_grid_desc_m,
-                                                  vgrad_grid_desc_n_o,
-                                                  ygrad_grid_desc_o0_m_o1,
-                                                  block_2_ctile_map,
-                                                  c0_matrix_mask,
-                                                  p_drop,
-                                                  ph);
+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(get_block_1d_id() % num_blocks_per_batch == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    p_a_grid + a_batch_offset,
+                    p_b_grid + b_batch_offset,
+                    z_matrix_ptr,
+                    p_b1_grid + b1_batch_offset,
+                    p_c_grid + c_batch_offset,
+                    p_lse_grid + lse_batch_offset,
+                    p_ygrad_grid + c_batch_offset,
+                    p_qgrad_grid + a_batch_offset,
+                    p_kgrad_grid + b_batch_offset,
+                    p_vgrad_grid + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    a_grid_desc_ak0_m_ak1,
+                    b_grid_desc_bk0_n_bk1,
+                    c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    b1_grid_desc_bk0_n_bk1,
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    lse_grid_desc_m,
+                    vgrad_grid_desc_n_o,
+                    ygrad_grid_desc_o0_m_o1,
+                    block_2_ctile_map,
+                    c0_matrix_mask,
+                    p_drop,
+                    ph);
+            }
+        }
+    }
+    else
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                      p_b_grid + b_batch_offset,
+                                                      z_matrix_ptr,
+                                                      p_b1_grid + b1_batch_offset,
+                                                      p_c_grid + c_batch_offset,
+                                                      p_lse_grid + lse_batch_offset,
+                                                      p_ygrad_grid + c_batch_offset,
+                                                      p_qgrad_grid + a_batch_offset,
+                                                      p_kgrad_grid + b_batch_offset,
+                                                      p_vgrad_grid + b1_batch_offset,
+                                                      p_shared,
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      acc_element_op,
+                                                      b1_element_op,
+                                                      c_element_op,
+                                                      a_grid_desc_ak0_m_ak1,
+                                                      b_grid_desc_bk0_n_bk1,
+                                                      c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                      b1_grid_desc_bk0_n_bk1,
+                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      lse_grid_desc_m,
+                                                      vgrad_grid_desc_n_o,
+                                                      ygrad_grid_desc_o0_m_o1,
+                                                      block_2_ctile_map,
+                                                      c0_matrix_mask,
+                                                      p_drop,
+                                                      ph);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -233,6 +275,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stabilizes
@@ -925,7 +968,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;
 
             return launch_and_time_kernel(stream_config,
                                           kernel,
...
@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -111,34 +112,75 @@ __global__ void
     ck::philox ph(seed, global_thread_id, offset);
     ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
-                                                  z_matrix_ptr,
-                                                  p_b1_grid + b1_batch_offset,
-                                                  p_c_grid + c_batch_offset,
-                                                  p_lse_grid + lse_batch_offset,
-                                                  p_ygrad_grid + c_batch_offset,
-                                                  p_qgrad_grid + a_batch_offset,
-                                                  p_kgrad_grid + b_batch_offset,
-                                                  p_vgrad_grid + b1_batch_offset,
-                                                  p_shared,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  acc_element_op,
-                                                  b1_element_op,
-                                                  c_element_op,
-                                                  a_grid_desc_ak0_m_ak1,
-                                                  b_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                  b1_grid_desc_bk0_n_bk1,
-                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  lse_grid_desc_m,
-                                                  vgrad_grid_desc_n_o,
-                                                  ygrad_grid_desc_m0_o_m1,
-                                                  block_2_ctile_map,
-                                                  c0_matrix_mask,
-                                                  p_drop,
-                                                  ph);
+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(get_block_1d_id() % num_blocks_per_batch == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    p_a_grid + a_batch_offset,
+                    p_b_grid + b_batch_offset,
+                    z_matrix_ptr,
+                    p_b1_grid + b1_batch_offset,
+                    p_c_grid + c_batch_offset,
+                    p_lse_grid + lse_batch_offset,
+                    p_ygrad_grid + c_batch_offset,
+                    p_qgrad_grid + a_batch_offset,
+                    p_kgrad_grid + b_batch_offset,
+                    p_vgrad_grid + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    a_grid_desc_ak0_m_ak1,
+                    b_grid_desc_bk0_n_bk1,
+                    c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    b1_grid_desc_bk0_n_bk1,
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    lse_grid_desc_m,
+                    vgrad_grid_desc_n_o,
+                    ygrad_grid_desc_m0_o_m1,
+                    block_2_ctile_map,
+                    c0_matrix_mask,
+                    p_drop,
+                    ph);
+            }
+        }
+    }
+    else
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                      p_b_grid + b_batch_offset,
+                                                      z_matrix_ptr,
+                                                      p_b1_grid + b1_batch_offset,
+                                                      p_c_grid + c_batch_offset,
+                                                      p_lse_grid + lse_batch_offset,
+                                                      p_ygrad_grid + c_batch_offset,
+                                                      p_qgrad_grid + a_batch_offset,
+                                                      p_kgrad_grid + b_batch_offset,
+                                                      p_vgrad_grid + b1_batch_offset,
+                                                      p_shared,
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      acc_element_op,
+                                                      b1_element_op,
+                                                      c_element_op,
+                                                      a_grid_desc_ak0_m_ak1,
+                                                      b_grid_desc_bk0_n_bk1,
+                                                      c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                      b1_grid_desc_bk0_n_bk1,
+                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      lse_grid_desc_m,
+                                                      vgrad_grid_desc_n_o,
+                                                      ygrad_grid_desc_m0_o_m1,
+                                                      block_2_ctile_map,
+                                                      c0_matrix_mask,
+                                                      p_drop,
+                                                      ph);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -232,6 +274,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stabilizes
@@ -927,7 +970,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;
 
             return launch_and_time_kernel(stream_config,
                                           kernel,
...
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,35 +100,76 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
-        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-        z_matrix_ptr,
-        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-        arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
-        arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
-        arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        acc_element_op,
-        b1_element_op,
-        c_element_op,
-        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
-        arg_ptr[group_id].lse_grid_desc_m_,
-        arg_ptr[group_id].vgrad_grid_desc_n_o_,
-        arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
-        arg_ptr[group_id].block_2_ctile_map_,
-        arg_ptr[group_id].c0_matrix_mask_,
-        p_dropout,
-        ph);
+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    z_matrix_ptr,
+                    arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+                    arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                    arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+                    arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+                    arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                    arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+                    arg_ptr[group_id].lse_grid_desc_m_,
+                    arg_ptr[group_id].vgrad_grid_desc_n_o_,
+                    arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
+                    arg_ptr[group_id].block_2_ctile_map_,
+                    arg_ptr[group_id].c0_matrix_mask_,
+                    p_dropout,
+                    ph);
+            }
+        }
+    }
+    else
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(
+            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+            z_matrix_ptr,
+            arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+            arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+            arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+            arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+            arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+            arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            acc_element_op,
+            b1_element_op,
+            c_element_op,
+            arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+            arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+            arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+            arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+            arg_ptr[group_id].lse_grid_desc_m_,
+            arg_ptr[group_id].vgrad_grid_desc_n_o_,
+            arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
+            arg_ptr[group_id].block_2_ctile_map_,
+            arg_ptr[group_id].c0_matrix_mask_,
+            p_dropout,
+            ph);
+    }
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -211,6 +253,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stabilizes
@@ -920,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;
 
             return launch_and_time_kernel(
                 stream_config,
...
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -99,35 +100,76 @@ __global__ void
         (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
-        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-        z_matrix_ptr,
-        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-        arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
-        arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
-        arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        acc_element_op,
-        b1_element_op,
-        c_element_op,
-        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-        arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
-        arg_ptr[group_id].lse_grid_desc_m_,
-        arg_ptr[group_id].vgrad_grid_desc_n_o_,
-        arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
-        arg_ptr[group_id].block_2_ctile_map_,
-        arg_ptr[group_id].c0_matrix_mask_,
-        p_dropout,
-        ph);
+    if constexpr(Deterministic)
+    {
+        for(index_t i = 0; i < num_blocks_per_batch; i++)
+        {
+            if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
+            {
+                GridwiseGemm::template Run<HasMainKBlockLoop>(
+                    arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    z_matrix_ptr,
+                    arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+                    arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+                    arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+                    arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+                    arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+                    arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+                    p_shared,
+                    a_element_op,
+                    b_element_op,
+                    acc_element_op,
+                    b1_element_op,
+                    c_element_op,
+                    arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+                    arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                    arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+                    arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+                    arg_ptr[group_id].lse_grid_desc_m_,
+                    arg_ptr[group_id].vgrad_grid_desc_n_o_,
+                    arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
+                    arg_ptr[group_id].block_2_ctile_map_,
+                    arg_ptr[group_id].c0_matrix_mask_,
+                    p_dropout,
+                    ph);
+            }
+        }
+    }
+    else
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(
+            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+            z_matrix_ptr,
+            arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+            arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+            arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+            arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
+            arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
+            arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+            arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            acc_element_op,
+            b1_element_op,
+            c_element_op,
+            arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+            arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+            arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+            arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
+            arg_ptr[group_id].lse_grid_desc_m_,
+            arg_ptr[group_id].vgrad_grid_desc_n_o_,
+            arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
+            arg_ptr[group_id].block_2_ctile_map_,
+            arg_ptr[group_id].c0_matrix_mask_,
+            p_dropout,
+            ph);
+    }
 #else
     ignore = group_kernel_args;
    ignore = group_count;
@@ -211,6 +253,7 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
+          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stabilizes
@@ -912,7 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
-                has_main_k_block_loop_>;
+                has_main_k_block_loop_,
+                Deterministic>;
 
             return launch_and_time_kernel(
                 stream_config,
...