Commit 454e9445 authored by letaoqin

remove-fwd-deterministic

parent b23b3d71
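
What this commit does (summarized from the hunks below): the compile-time `Deterministic` switch is removed from the forward multi-head-attention path. The example files drop their `static constexpr bool Deterministic` definitions, the device ops `DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2` and `DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2` and the gridwise kernel lose the trailing `bool Deterministic` template parameter, and the kernels now always map one workgroup per output tile per batch instead of the deterministic single-workgroup-per-batch loop. A minimal sketch of how an instantiation changes, with the long tile/transfer parameter list elided (illustrative only, not a literal excerpt from the diff):

// Before this commit, the last two template arguments were:
//         MaskingSpec,     // MaskingSpecialization
//         Deterministic>;
// After this commit, MaskingSpec is the final template argument:
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
        /* ...unchanged tile and transfer parameters... */
        MaskingSpec>; // MaskingSpecialization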
@@ -75,7 +75,6 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -75,7 +75,6 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
 
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -75,7 +75,6 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -75,7 +75,6 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 
 using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
...
@@ -75,7 +75,6 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -47,8 +47,7 @@ template <typename GridwiseGemm,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -78,7 +77,6 @@ __global__ void
         const LSEGridDescriptor_M lse_grid_desc_m,
         const Block2CTileMap block_2_ctile_map,
         const index_t batch_count,
-        const index_t mblock,
         const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
         const C0MatrixMask c0_matrix_mask,
         const uint8_t p_dropout_in_uint8_t,
@@ -122,43 +120,6 @@ __global__ void
 
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
 
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < mblock; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                p_a_grid + a_batch_offset,
-                p_b_grid + b_batch_offset,
-                tmp_p_d0_grid,
-                p_b1_grid + b1_batch_offset,
-                p_c_grid + c_batch_offset,
-                p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
-                p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                a_grid_desc_ak0_m_ak1,
-                b_grid_desc_bk0_n_bk1,
-                d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                b1_grid_desc_bk0_n_bk1,
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                lse_grid_desc_m,
-                block_2_ctile_map,
-                c0_matrix_mask,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                z_random_matrix_offset,
-                raw_n_padded,
-                i);
-        }
-    }
-    else
-    {
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
@@ -186,9 +147,7 @@ __global__ void
         p_dropout_rescale,
         ph,
         z_random_matrix_offset,
-        raw_n_padded,
-        0);
-    }
+        raw_n_padded);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -211,7 +170,6 @@ __global__ void
     ignore = lse_grid_desc_m;
     ignore = block_2_ctile_map;
     ignore = batch_count;
-    ignore = mblock;
    ignore = compute_base_ptr_of_batch;
    ignore = c0_matrix_mask;
    ignore = p_dropout_in_uint8_t;
@@ -296,7 +254,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
@@ -576,8 +533,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
 
     // Argument
     // FIXME: constness
@@ -833,9 +789,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         }
 
         const index_t grid_size =
-            (Deterministic ? 1
-                           : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
-            arg.batch_count_;
+            arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
 
         // Gemm0_K
         const auto K =
@@ -843,8 +797,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 
         float ave_time = 0;
 
-        auto launch_kernel =
-            [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_,
+                                 auto is_dropout_,
+                                 auto is_lse_storing_) {
            const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
                GridwiseGemm,
                ADataType, // TODO: distiguish A/B datatype
@@ -870,11 +825,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                C0MatrixMask,
                has_main_k_block_loop_,
                is_dropout_,
-               is_lse_storing_,
-               Deterministic>;
+               is_lse_storing_>;
 
-            return launch_and_time_kernel(
-                stream_config,
+            return launch_and_time_kernel(stream_config,
                kernel,
                dim3(grid_size),
                dim3(BlockSize),
@@ -900,7 +853,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                arg.lse_grid_desc_m_,
                arg.block_2_ctile_map_,
                arg.batch_count_,
-               arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
               arg.compute_base_ptr_of_batch_,
               arg.c0_matrix_mask_,
               arg.p_dropout_in_uint8_t_,
...
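
Launch-side effect in the batched device op above, restated for clarity (a hedged sketch; the names follow the hunks, this is not a drop-in excerpt): with `Deterministic` gone, the grid is always sized to one workgroup per output tile per batch, and the extra `mblock` kernel argument that drove the serial per-M-block loop is dropped.

// Before: deterministic mode launched a single workgroup per batch and the
// kernel looped over all M-blocks serially inside `if constexpr(Deterministic)`.
//     const index_t grid_size =
//         (Deterministic ? 1
//                        : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
//         arg.batch_count_;
// After: one workgroup per output tile per batch; no in-kernel M-block loop.
const index_t grid_size =
    arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;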
@@ -35,8 +35,7 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -87,7 +86,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
+        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
 
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -113,49 +112,6 @@ __global__ void
         tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
     }
 
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < num_blocks_per_batch; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-                tmp_p_d0_grid,
-                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-                arg_ptr[group_id].p_z_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-                arg_ptr[group_id].p_lse_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].lse_grid_desc_m_,
-                arg_ptr[group_id].block_2_ctile_map_,
-                arg_ptr[group_id].c0_matrix_mask_,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                arg_ptr[group_id].z_random_matrix_offset_ +
-                    g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-                arg_ptr[group_id].raw_n_padded_,
-                i);
-        }
-    }
-    else
-    {
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -164,8 +120,7 @@ __global__ void
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
         arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-        arg_ptr[group_id].p_lse_grid_ == nullptr
-            ? nullptr
+        arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
             : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
@@ -188,9 +143,7 @@ __global__ void
         ph,
         arg_ptr[group_id].z_random_matrix_offset_ +
             g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-        arg_ptr[group_id].raw_n_padded_,
-        0);
-    }
+        arg_ptr[group_id].raw_n_padded_);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -279,7 +232,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
@@ -597,8 +549,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
 
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -783,8 +734,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
            const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
            const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
            const index_t grid_size_grp =
-                (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
-                batch_count;
+                block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
            const index_t BlockEnd = grid_size_ + grid_size_grp;
 
            // batch stride
@@ -958,8 +908,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                CElementwiseOperation,
                has_main_k_block_loop_,
                use_dropout_,
-               is_lse_storing_,
-               Deterministic>;
+               is_lse_storing_>;
 
                return launch_and_time_kernel(
                    stream_config,
...
@@ -94,7 +94,6 @@ template <typename FloatAB,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
-          bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 {
@@ -531,8 +530,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         FloatGemmAcc p_dropout_rescale,
         ck::philox& ph,
         const index_t z_random_matrix_offset,
-        const index_t raw_n_padded,
-        const index_t block_idx_m)
+        const index_t raw_n_padded)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -557,7 +555,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
            return;
        }
 
-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_m = block_work_idx[I0];
 
        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
@@ -1145,11 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                0),
            tensor_operation::element_wise::PassThrough{}};
 
-        if constexpr(Deterministic)
-        {
-            block_sync_lds();
-        }
-
        do
        {
            auto n_block_data_idx_on_grid =
...