"src/include/ConstantTensorDescriptor.hpp" did not exist on "17f3d2d4bccebcc3a70606a916f93dc90e5eaa3a"
Commit 99ebfeba authored by danyao12

correct deterministic mode

parent 84a81ae2
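The hunks below thread a new `Deterministic` flag from the example files down to the device and gridwise ops. Read together, the pattern appears to be: when `Deterministic` is true the host launches only one workgroup per batch, and that workgroup walks every M-tile of the output serially, so the floating-point accumulation order no longer depends on how workgroups are scheduled. A minimal host-side C++ sketch of that dispatch, using simplified, hypothetical names (`num_m_tiles`, `run_tile`) in place of the CK machinery:

```cpp
#include <cstdio>

// Hypothetical stand-ins for CK's block-to-C-tile map and GridwiseGemm::Run().
constexpr int num_m_tiles = 8; // per-batch block count (CalculateGridSize in CK)
constexpr int batch_count = 4;

template <bool Deterministic>
void run_tile(int batch, int m_tile)
{
    // In CK this is GridwiseGemm::Run(..., block_idx_m); here we just trace it.
    std::printf("batch %d, m-tile %d, deterministic=%d\n", batch, m_tile, Deterministic);
}

template <bool Deterministic>
void launch()
{
    // One "block" per batch in deterministic mode,
    // otherwise one block per (batch, M-tile) pair.
    const int grid_size = (Deterministic ? 1 : num_m_tiles) * batch_count;

    for(int block_id = 0; block_id < grid_size; ++block_id) // emulates the GPU grid
    {
        if constexpr(Deterministic)
        {
            const int batch = block_id; // grid has exactly batch_count blocks
            for(int i = 0; i < num_m_tiles; ++i) // serial loop => fixed accumulation order
                run_tile<Deterministic>(batch, i);
        }
        else
        {
            const int batch  = block_id / num_m_tiles;
            const int m_tile = block_id % num_m_tiles;
            run_tile<Deterministic>(batch, m_tile);
        }
    }
}

int main()
{
    launch<true>();  // deterministic: batch_count blocks, serial M loop
    launch<false>(); // default: batch_count * num_m_tiles blocks
}
```

The real kernels make the same split with `if constexpr(Deterministic)` around the `GridwiseGemm::Run` call, as the hunks below show.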
......@@ -75,6 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
......@@ -145,7 +146,8 @@ using DeviceGemmInstance =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -215,7 +217,8 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -285,7 +288,8 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
// Ref Gemm0: DataType in, AccDataType out
......
......@@ -104,6 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
......@@ -178,7 +179,8 @@ using DeviceGemmInstanceFWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
......@@ -247,7 +249,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -317,7 +320,8 @@ using DeviceGemmInstanceFWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
......@@ -386,7 +390,8 @@ using DeviceGemmInstanceBWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......@@ -455,7 +460,8 @@ using DeviceGemmInstanceBWD =
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock,
// MaskingSpec>;
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -525,7 +531,8 @@ using DeviceGemmInstanceFWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......@@ -594,7 +601,8 @@ using DeviceGemmInstanceBWD =
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
// Ref Gemm0: S = alpha * Q * K^T
......
......@@ -75,6 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
......@@ -145,7 +146,8 @@ using DeviceGemmInstance =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -215,7 +217,8 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -285,7 +288,8 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
// Ref Gemm0: DataType in, AccDataType out
......
......@@ -103,6 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
......@@ -177,7 +178,8 @@ using DeviceGemmInstanceFWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
......@@ -246,7 +248,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -316,7 +319,8 @@ using DeviceGemmInstanceFWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
......@@ -385,7 +389,8 @@ using DeviceGemmInstanceBWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......@@ -454,7 +459,8 @@ using DeviceGemmInstanceBWD =
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock,
// MaskingSpec>;
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -524,7 +530,8 @@ using DeviceGemmInstanceFWD =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......@@ -593,7 +600,8 @@ using DeviceGemmInstanceBWD =
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
// Ref Gemm0: S = alpha * Q * K^T
......
......@@ -82,6 +82,7 @@ __global__ void
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_drop,
......@@ -115,9 +116,7 @@ __global__ void
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
if(get_block_1d_id() % num_blocks_per_batch == i)
for(index_t i = 0; i < mblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
......@@ -147,8 +146,8 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
}
ph,
i);
}
}
else
......@@ -180,7 +179,8 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
ph,
0);
}
#else
ignore = p_a_grid;
......@@ -707,7 +707,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// Argument
struct Argument : public BaseArgument
......@@ -941,7 +942,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
arg.batch_count_;
float ave_time = 0;
......@@ -971,7 +974,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(stream_config,
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
......@@ -1001,6 +1005,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
......
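In the batched backward device ops above, the host shrinks the grid to one block per batch when `Deterministic` is set, but still passes the full per-batch block count (`CalculateGridSize(arg.y_grid_desc_m_o_)`) into the kernel as the new `mblock` argument, so the kernel knows how many M-tiles its serial loop must cover. A hedged sketch of that host-side logic, with a placeholder `Block2CTileMap` instead of the real descriptor-driven map:

```cpp
#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Placeholder for CK's block-to-C-tile map; only the piece used here.
struct Block2CTileMap
{
    index_t blocks_per_batch; // what CalculateGridSize(y_grid_desc_m_o) would return
    index_t CalculateGridSize() const { return blocks_per_batch; }
};

template <bool Deterministic>
index_t grid_size(const Block2CTileMap& map, index_t batch_count)
{
    // Mirrors: (Deterministic ? 1 : CalculateGridSize(...)) * batch_count
    return (Deterministic ? 1 : map.CalculateGridSize()) * batch_count;
}

int main()
{
    Block2CTileMap map{/*blocks_per_batch=*/12};
    assert(grid_size<false>(map, 3) == 36); // one block per (batch, M-tile)
    assert(grid_size<true>(map, 3) == 3);   // one block per batch
    // The kernel still receives map.CalculateGridSize() as `mblock`, so its
    // deterministic loop covers all 12 M-tiles of each batch.
}
```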
......@@ -81,6 +81,7 @@ __global__ void
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_drop,
......@@ -114,9 +115,7 @@ __global__ void
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
if(get_block_1d_id() % num_blocks_per_batch == i)
for(index_t i = 0; i < mblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
......@@ -146,8 +145,8 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
}
ph,
i);
}
}
else
......@@ -179,7 +178,8 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph);
ph,
0);
}
#else
ignore = p_a_grid;
......@@ -706,7 +706,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// Argument
struct Argument : public BaseArgument
......@@ -939,7 +940,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
arg.batch_count_;
// Gemm0_K
const auto K =
......@@ -973,7 +976,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(stream_config,
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
......@@ -1003,6 +1007,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
......
......@@ -45,7 +45,8 @@ template <typename GridwiseGemm,
typename C0MatrixMask,
bool HasMainKBlockLoop,
bool IsDropout,
bool IsLseStoring>
bool IsLseStoring,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -72,6 +73,7 @@ __global__ void
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
......@@ -101,6 +103,39 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
if constexpr(Deterministic)
{
for(index_t i = 0; i < mblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
......@@ -124,7 +159,9 @@ __global__ void
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_rescale,
ph);
ph,
0);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -216,6 +253,7 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionForward<NumDimG,
......@@ -476,7 +514,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// Argument
// FIXME: constness
......@@ -695,7 +734,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
arg.batch_count_;
// Gemm0_K
const auto K =
......@@ -703,9 +744,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_,
auto is_dropout_,
auto is_lse_storing_) {
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
......@@ -729,9 +769,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
C0MatrixMask,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_>;
is_lse_storing_,
Deterministic>;
return launch_and_time_kernel(stream_config,
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
......@@ -755,6 +797,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
......
......@@ -79,7 +79,7 @@ __global__ void
// per-group batch offset
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
......@@ -103,8 +103,6 @@ __global__ void
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
......@@ -134,8 +132,8 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph);
}
ph,
i);
}
}
else
......@@ -168,7 +166,8 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph);
ph,
0);
}
#else
ignore = group_kernel_args;
......@@ -643,7 +642,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -825,7 +825,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
......
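The grouped variants need one extra piece of bookkeeping: each group owns a contiguous range of block ids, and the batch index within a group is derived from `block_id - block_start`. With one block per batch in deterministic mode, the divisor changes from `num_blocks_per_batch` to 1, which is what the `(Deterministic ? 1 : num_blocks_per_batch)` edits above and below implement. A small sketch under those assumptions, with hypothetical names:

```cpp
#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Simplified per-group bookkeeping, loosely modelled on the grouped kernel args.
struct GroupArgs
{
    index_t block_start;          // first block id owned by this group
    index_t num_blocks_per_batch; // blocks per batch in the non-deterministic layout
};

// Which batch inside the group does this block handle?
template <bool Deterministic>
index_t batch_index(const GroupArgs& g, index_t block_id)
{
    // Deterministic mode launches one block per batch, so consecutive block ids
    // map directly to consecutive batches (divide by 1 instead of the tile count).
    return (block_id - g.block_start) / (Deterministic ? 1 : g.num_blocks_per_batch);
}

int main()
{
    GroupArgs g{/*block_start=*/16, /*num_blocks_per_batch=*/4};
    assert((batch_index<false>(g, 22) == 1)); // block 22 -> batch (22 - 16) / 4 = 1
    assert((batch_index<true>(g, 22) == 6));  // one block per batch: (22 - 16) / 1 = 6
}
```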
......@@ -79,7 +79,7 @@ __global__ void
// per-group batch offset
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
......@@ -103,8 +103,6 @@ __global__ void
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
......@@ -134,8 +132,8 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph);
}
ph,
i);
}
}
else
......@@ -168,7 +166,8 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout,
ph);
ph,
0);
}
#else
ignore = group_kernel_args;
......@@ -636,7 +635,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -818,7 +818,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
......
......@@ -33,7 +33,8 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
bool IsDropout,
bool IsLseStoring>
bool IsLseStoring,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -83,7 +84,7 @@ __global__ void
// per-group batch offset
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
(block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
......@@ -98,6 +99,44 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
......@@ -105,7 +144,8 @@ __global__ void
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
arg_ptr[group_id].p_lse_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
......@@ -124,7 +164,9 @@ __global__ void
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_rescale,
ph);
ph,
0);
}
#else
ignore = group_kernel_args;
ignore = group_count;
......@@ -206,6 +248,7 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceGroupedMultiheadAttentionForward<NumDimG,
......@@ -487,7 +530,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -638,7 +682,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
......@@ -778,7 +823,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_>;
is_lse_storing_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......
......@@ -86,6 +86,7 @@ template <typename InputDataType,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
......@@ -1265,7 +1266,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph)
ck::philox& ph,
const index_t block_idx_m)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
......@@ -1305,9 +1307,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
// const index_t o_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
......@@ -1512,7 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx[I0], // mblock
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
......@@ -1574,7 +1578,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
......@@ -1720,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs for y
......@@ -2320,7 +2324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
......
......@@ -86,6 +86,7 @@ template <typename InputDataType,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
......@@ -1175,7 +1176,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph)
ck::philox& ph,
const index_t block_idx_m)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
......@@ -1215,9 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
const index_t o_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
......@@ -1444,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx[I0], // mblock
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
......@@ -1506,7 +1510,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
......@@ -1643,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
......@@ -2270,7 +2274,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
......
......@@ -85,6 +85,7 @@ template <typename FloatAB,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
......@@ -445,7 +446,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox ph)
ck::philox& ph,
const index_t block_idx_m)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -470,9 +472,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
const index_t gemm1_n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
......@@ -835,7 +839,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
InMemoryDataOperationEnum::Set,
1,
false>{lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx[I0], // mblock
make_multi_index(block_work_idx_m, // mblock
0, // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4]), // mperxdl
......@@ -897,7 +901,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
......@@ -1319,7 +1323,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
......
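On the gridwise side, `Run` now takes the serial loop index as an extra `block_idx_m` parameter and substitutes it for the mapped M-tile index whenever `Deterministic` is set; every former use of `block_work_idx[I0]` becomes `block_work_idx_m`. A rough standalone sketch of that selection, assuming a simplified `block_work_idx` pair in place of the block-to-C-tile map output:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

using index_t = int32_t;

constexpr index_t MPerBlock = 128;

// How the gridwise Run() picks its M-tile origin, in simplified form.
template <bool Deterministic>
index_t m_block_data_idx_on_grid(const std::array<index_t, 2>& block_work_idx,
                                 index_t block_idx_m)
{
    // In deterministic mode the caller's serial loop index overrides the
    // block-to-tile mapping; otherwise the mapped M index is used as before.
    const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[0];
    return block_work_idx_m * MPerBlock;
}

int main()
{
    std::array<index_t, 2> block_work_idx{3, 0}; // mapped (M, N) tile indices
    std::printf("default:       %d\n", m_block_data_idx_on_grid<false>(block_work_idx, 7));
    std::printf("deterministic: %d\n", m_block_data_idx_on_grid<true>(block_work_idx, 7));
}
```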