Commit 6c971dc8 authored by letaoqin

Merge branch 'mha-train-develop' into mha-train-develop-fix-issupport

parents b76c8e62 f27f9158
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -1424,8 +1424,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                         void* p_qgrad_grid,
                         void* p_kgrad_grid,
                         void* p_vgrad_grid,
-                        const D0DataType* p_acc0_bias,
-                        const D1DataType* p_acc1_bias,
+                        const void* p_acc0_bias,
+                        const void* p_acc1_bias,
                         void* p_d0grad_grid,
                         void* p_d1grad_grid,
                         const std::vector<index_t>& a_gs_ms_ks_lengths,
...
@@ -1281,10 +1281,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                         void* p_qgrad_grid,
                         void* p_kgrad_grid,
                         void* p_vgrad_grid,
-                        const D0DataType* p_acc0_bias,
-                        const D1DataType* p_acc1_bias,
-                        D0DataType* p_d0grad_grid,
-                        D1DataType* p_d1grad_grid,
+                        const void* p_acc0_bias,
+                        const void* p_acc1_bias,
+                        void* p_d0grad_grid,
+                        void* p_d1grad_grid,
                         const std::vector<index_t>& a_gs_ms_ks_lengths,
                         const std::vector<index_t>& a_gs_ms_ks_strides,
                         const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1323,8 +1323,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
-           static_cast<const D0DataType*>(p_d0grad_grid),
-           static_cast<const D1DataType*>(p_d1grad_grid),
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
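Both backward device ops above now take the attention-bias pointers as type-erased const void* (and the bias-gradient grids as plain void*), with the typed static_cast done once when the Argument is built; the gradient grids also drop the stray const, since they are written by the kernel. A minimal standalone sketch of this pattern, with float standing in for D0DataType (illustrative only, not the library API):

```cpp
#include <cassert>

// Hypothetical stand-ins: float plays the role of D0DataType.
struct Argument
{
    const float* p_acc0_bias_;
    float* p_d0grad_grid_; // written by the kernel, hence non-const
};

// Type-erased entry point; the casts happen once, "in struct Argument".
Argument MakeArgument(const void* p_acc0_bias, void* p_d0grad_grid)
{
    return Argument{static_cast<const float*>(p_acc0_bias),
                    static_cast<float*>(p_d0grad_grid)};
}

int main()
{
    float bias[4]  = {0, 1, 2, 3};
    float dgrad[4] = {};
    Argument arg   = MakeArgument(bias, dgrad);
    assert(arg.p_acc0_bias_ == bias && arg.p_d0grad_grid_ == dgrad);
    return 0;
}
```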
@@ -47,8 +47,7 @@ template <typename GridwiseGemm,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -78,7 +77,6 @@ __global__ void
             const LSEGridDescriptor_M lse_grid_desc_m,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
-            const index_t mblock,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
             const uint8_t p_dropout_in_uint8_t,
@@ -122,73 +120,34 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < mblock; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                p_a_grid + a_batch_offset,
-                p_b_grid + b_batch_offset,
-                tmp_p_d0_grid,
-                p_b1_grid + b1_batch_offset,
-                p_c_grid + c_batch_offset,
-                p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
-                p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                a_grid_desc_ak0_m_ak1,
-                b_grid_desc_bk0_n_bk1,
-                d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                b1_grid_desc_bk0_n_bk1,
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                lse_grid_desc_m,
-                block_2_ctile_map,
-                c0_matrix_mask,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                z_random_matrix_offset,
-                raw_n_padded,
-                i);
-        }
-    }
-    else
-    {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-            p_a_grid + a_batch_offset,
-            p_b_grid + b_batch_offset,
-            tmp_p_d0_grid,
-            p_b1_grid + b1_batch_offset,
-            p_c_grid + c_batch_offset,
-            p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
-            p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            acc_element_op,
-            b1_element_op,
-            c_element_op,
-            a_grid_desc_ak0_m_ak1,
-            b_grid_desc_bk0_n_bk1,
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            b1_grid_desc_bk0_n_bk1,
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            lse_grid_desc_m,
-            block_2_ctile_map,
-            c0_matrix_mask,
-            p_dropout_in_uint8_t,
-            p_dropout_rescale,
-            ph,
-            z_random_matrix_offset,
-            raw_n_padded,
-            0);
-    }
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+        p_a_grid + a_batch_offset,
+        p_b_grid + b_batch_offset,
+        tmp_p_d0_grid,
+        p_b1_grid + b1_batch_offset,
+        p_c_grid + c_batch_offset,
+        p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+        p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        acc_element_op,
+        b1_element_op,
+        c_element_op,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        b1_grid_desc_bk0_n_bk1,
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        lse_grid_desc_m,
+        block_2_ctile_map,
+        c0_matrix_mask,
+        p_dropout_in_uint8_t,
+        p_dropout_rescale,
+        ph,
+        z_random_matrix_offset,
+        raw_n_padded);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -211,7 +170,6 @@ __global__ void
     ignore = lse_grid_desc_m;
     ignore = block_2_ctile_map;
     ignore = batch_count;
-    ignore = mblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
     ignore = p_dropout_in_uint8_t;
@@ -296,7 +254,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
@@ -576,8 +533,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
     // Argument
     // FIXME: constness
@@ -833,9 +789,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         }
         const index_t grid_size =
-            (Deterministic ? 1
-                           : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
-            arg.batch_count_;
+            arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
         // Gemm0_K
         const auto K =
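The Deterministic path being removed here launched one workgroup per batch and had it sweep every M tile serially (the mblock loop above), presumably to pin down the accumulation order at the cost of parallelism; with it gone, the grid always holds one workgroup per (batch, M tile). A small sketch of the two grid-size formulas, using hypothetical helper names:

```cpp
#include <cstdint>

using index_t = int32_t;

// Removed Deterministic == true path: one block per batch, which then loops i = 0..mblock-1.
index_t grid_size_deterministic(index_t /*blocks_per_batch*/, index_t batch_count)
{
    return 1 * batch_count;
}

// Remaining path (the only one after this commit): one block per (batch, M tile).
index_t grid_size_default(index_t blocks_per_batch, index_t batch_count)
{
    return blocks_per_batch * batch_count;
}
```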
@@ -843,73 +797,71 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         float ave_time = 0;
-        auto launch_kernel =
-            [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_,
+                                 auto is_dropout_,
+                                 auto is_lse_storing_) {
             const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 D0DataType,
                 CDataType,
                 ZDataType,
                 LSEDataType,
                 GemmAccDataType,
                 AElementwiseOperation,
                 BElementwiseOperation,
                 AccElementwiseOperation,
                 B1ElementwiseOperation,
                 CElementwiseOperation,
                 DeviceOp::AGridDesc_AK0_M_AK1,
                 DeviceOp::BGridDesc_BK0_N_BK1,
                 typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                 DeviceOp::B1GridDesc_BK0_N_BK1,
                 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                 DeviceOp::LSEGridDesc_M,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
                 has_main_k_block_loop_,
                 is_dropout_,
-                is_lse_storing_,
-                Deterministic>;
+                is_lse_storing_>;
-            return launch_and_time_kernel(
-                stream_config,
+            return launch_and_time_kernel(stream_config,
                 kernel,
                 dim3(grid_size),
                 dim3(BlockSize),
                 0,
                 arg.p_a_grid_,
                 arg.p_b_grid_,
                 arg.p_d0_grid_,
                 arg.p_b1_grid_,
                 arg.p_c_grid_,
                 arg.p_z_grid_,
                 arg.p_lse_grid_,
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.acc_element_op_,
                 arg.b1_element_op_,
                 arg.c_element_op_,
                 arg.a_grid_desc_ak0_m_ak1_,
                 arg.b_grid_desc_bk0_n_bk1_,
                 arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                 arg.lse_grid_desc_m_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                 arg.compute_base_ptr_of_batch_,
                 arg.c0_matrix_mask_,
                 arg.p_dropout_in_uint8_t_,
                 arg.p_dropout_rescale_,
                 arg.seed_,
                 arg.offset_,
                 arg.m_raw_padded_,
                 arg.n_raw_padded_);
         };
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
         // to concern Gemm0's loop
...
@@ -1027,7 +1027,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 z_grid_desc_g_m_n,
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
...
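This hunk (and the matching ones in the sibling device ops below) changes how the per-batch LSE offset is derived: instead of the element-space size of the per-batch LSE descriptor, it uses the stride of the innermost G dimension from the problem descriptor, which also covers non-packed LSE layouts. A tiny sketch of the assumed lookup, with illustrative names:

```cpp
#include <cstdint>
#include <vector>

using index_t = int32_t;

// Assumed LSE layout [G0, ..., G_{NumDimG-1}, M]: the batch-to-batch offset is the stride of
// the last G dimension; it equals the per-batch element count only when LSE is densely packed.
index_t lse_batch_stride(const std::vector<index_t>& lse_gs_ms_strides, index_t num_dim_g)
{
    return lse_gs_ms_strides[num_dim_g - 1];
}
```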
@@ -1098,7 +1098,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 z_grid_desc_g_m_n,
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
...
@@ -918,7 +918,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 z_grid_desc_g_m_n,
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
...
@@ -448,19 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                       make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
-    //
-    // dP = dY * V^T
-    //
-    // YGrad in Gemm A position
-    static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
-                                                const std::vector<index_t>& y_gs_ms_os_strides)
-    {
-        return Transform::MakeAGridDescriptor_AK0_M_AK1(
-            Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
-            Number<Y_O1>{});
-    }
     // V in Gemm B position
     static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
                                             const std::vector<index_t>& v_gs_os_ns_strides)
@@ -988,7 +975,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 z_grid_desc_g_m_n,
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
...
@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
                 z_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
...
@@ -35,8 +35,7 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -87,7 +86,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
+        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -113,84 +112,38 @@ __global__ void
         tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
     }
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < num_blocks_per_batch; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-                tmp_p_d0_grid,
-                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-                arg_ptr[group_id].p_z_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-                arg_ptr[group_id].p_lse_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].lse_grid_desc_m_,
-                arg_ptr[group_id].block_2_ctile_map_,
-                arg_ptr[group_id].c0_matrix_mask_,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                arg_ptr[group_id].z_random_matrix_offset_ +
-                    g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-                arg_ptr[group_id].raw_n_padded_,
-                i);
-        }
-    }
-    else
-    {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-            tmp_p_d0_grid,
-            arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-            arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-            arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-                                                   : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-            arg_ptr[group_id].p_lse_grid_ == nullptr
-                ? nullptr
-                : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-            // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            acc_element_op,
-            b1_element_op,
-            c_element_op,
-            arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-            arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-            arg_ptr[group_id].lse_grid_desc_m_,
-            arg_ptr[group_id].block_2_ctile_map_,
-            arg_ptr[group_id].c0_matrix_mask_,
-            p_dropout_in_uint8_t,
-            p_dropout_rescale,
-            ph,
-            arg_ptr[group_id].z_random_matrix_offset_ +
-                g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-            arg_ptr[group_id].raw_n_padded_,
-            0);
-    }
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+        tmp_p_d0_grid,
+        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
+                                                 : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        acc_element_op,
+        b1_element_op,
+        c_element_op,
+        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+        arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+        arg_ptr[group_id].lse_grid_desc_m_,
+        arg_ptr[group_id].block_2_ctile_map_,
+        arg_ptr[group_id].c0_matrix_mask_,
+        p_dropout_in_uint8_t,
+        p_dropout_rescale,
+        ph,
+        arg_ptr[group_id].z_random_matrix_offset_ +
+            g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
+        arg_ptr[group_id].raw_n_padded_);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -279,7 +232,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
@@ -597,8 +549,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -783,8 +734,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
             const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
             const index_t grid_size_grp =
-                (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
-                batch_count;
+                block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
             const index_t BlockEnd = grid_size_ + grid_size_grp;
             // batch stride
@@ -795,7 +745,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 b1_grid_desc_g_n_k,
                 c_grid_desc_g_m_n,
                 z_grid_desc_g_m_n,
-                type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+                type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
             // C0 mask
             const auto c0_matrix_mask =
@@ -958,8 +908,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 CElementwiseOperation,
                 has_main_k_block_loop_,
                 use_dropout_,
-                is_lse_storing_,
-                Deterministic>;
+                is_lse_storing_>;
             return launch_and_time_kernel(
                 stream_config,
...
@@ -5,6 +5,7 @@
 #include <iostream>
 #include <sstream>
+#include <cstring>
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -687,12 +688,34 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
                 some_has_main_k_block_loop |= y;
             }
-            hipGetErrorString(
-                hipMemcpyWithStream(arg.p_workspace_,
-                                    arg.group_kernel_args_.data(),
-                                    arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
-                                    hipMemcpyHostToDevice,
-                                    stream_config.stream_id_));
+            hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
+
+            HIP_CHECK_ERROR(hipStreamIsCapturing(stream_config.stream_id_, &status));
+
+            if(status == hipStreamCaptureStatusActive)
+            {
+                size_t copy_size = arg.group_kernel_args_.size() * sizeof(GroupKernelArg);
+                // ToDO: when to release this memory buffer?
+                char* persistent_ptr = new char[copy_size];
+                (void)std::memcpy(persistent_ptr, arg.group_kernel_args_.data(), copy_size);
+                HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
+                                               persistent_ptr,
+                                               copy_size,
+                                               hipMemcpyHostToDevice,
+                                               stream_config.stream_id_));
+            }
+            else
+            {
+                HIP_CHECK_ERROR(
+                    hipMemcpyAsync(arg.p_workspace_,
+                                   arg.group_kernel_args_.data(),
+                                   arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
+                                   hipMemcpyHostToDevice,
+                                   stream_config.stream_id_));
+            }
             float ave_time = 0;
...
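The hunk above makes the kernel-argument upload graph-capture safe: hipMemcpyWithStream blocks the host and cannot be recorded into a hipGraph, so the code first queries hipStreamIsCapturing and, during capture, stages the arguments in a host buffer that outlives the call before enqueuing hipMemcpyAsync (the buffer is deliberately not freed yet, per the TODO). A standalone sketch of the same pattern over a plain byte buffer (an assumed helper, not the library code):

```cpp
#include <hip/hip_runtime.h>
#include <cstring>
#include <vector>

// Enqueue a host-to-device copy that works both on a normal stream and under stream capture.
inline hipError_t copy_args_to_device(void* dst, const std::vector<char>& host_args, hipStream_t stream)
{
    hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
    hipError_t err = hipStreamIsCapturing(stream, &status);
    if(err != hipSuccess)
        return err;

    const void* src = host_args.data();
    if(status == hipStreamCaptureStatusActive)
    {
        // The captured graph may replay this copy later, so the source must stay valid;
        // it is leaked on purpose here, mirroring the TODO in the diff.
        char* persistent = new char[host_args.size()];
        std::memcpy(persistent, host_args.data(), host_args.size());
        src = persistent;
    }
    return hipMemcpyAsync(dst, src, host_args.size(), hipMemcpyHostToDevice, stream);
}
```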
@@ -320,18 +320,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
         if(O != K)
         {
+            std::cerr << "O = " << O << " K = " << K << std::endl;
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
         if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
+            std::cerr << "M = " << M << " O = " << O
+                      << " y_grid_desc_m_o = " << y_grid_desc_m_o.GetLength(I0) << " , "
+                      << y_grid_desc_m_o.GetLength(I1) << std::endl;
+            std::cerr << "Un-matched sizes!" << std::endl;
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
              O % Gemm1NPerBlock == 0))
         {
+            std::cerr << "M = " << M << " N = " << N << " O = " << O << std::endl;
+            std::cerr << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock
+                      << " KPerBlock = " << KPerBlock << std::endl;
+            std::cerr << "Un-aligned sizes!" << std::endl;
             return false;
         }
...
@@ -94,7 +94,6 @@ template <typename FloatAB,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
-          bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 {
@@ -531,8 +530,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                FloatGemmAcc p_dropout_rescale,
                                ck::philox& ph,
                                const index_t z_random_matrix_offset,
-                               const index_t raw_n_padded,
-                               const index_t block_idx_m)
+                               const index_t raw_n_padded)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -557,7 +555,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             return;
         }
-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_m = block_work_idx[I0];
         // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
@@ -1145,11 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 0),
             tensor_operation::element_wise::PassThrough{}};
-        if constexpr(Deterministic)
-        {
-            block_sync_lds();
-        }
         do
         {
             auto n_block_data_idx_on_grid =
...
@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
     return u.fp32;
 }
+#ifdef USE_RTN_BF16_CONVERT
+// Convert fp32 to bf16 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    // When the exponent bits are not all 1s, then the value is zero, normal,
+    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
+    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
+    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
+    // least significant bits of the float mantissa are greater than 0x8000,
+    // or if they are equal to 0x8000 and the least significant bit of the
+    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
+    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
+    // has the value 0x7f, then incrementing it causes it to become 0x00 and
+    // the exponent is incremented by one, which is the next higher FP value
+    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
+    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
+    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
+    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
+    // incrementing it causes it to become an exponent of 0xFF and a mantissa
+    // of 0x00, which is Inf, the next higher value to the unrounded value.
+    bool flag0 = ~u.int32 & 0x7f800000;
+
+    // When all of the exponent bits are 1, the value is Inf or NaN.
+    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
+    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
+    // bit being 1. Signaling NaN is indicated by the most significant
+    // mantissa bit being 0 but some other bit(s) being 1. If any of the
+    // lower 16 bits of the mantissa are 1, we set the least significant bit
+    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
+    // the bfloat16's mantissa bits are all 0.
+    bool flag1 = !flag0 && (u.int32 & 0xffff);
+
+    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
+    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
+
+    return uint16_t(u.int32 >> 16);
+}
+#else
 // convert fp32 to bfp16
 template <>
 inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
     return uint16_t(u.int32 >> 16);
 }
+#endif
 // convert bfp16 to fp16 via fp32
 template <>
...
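The new USE_RTN_BF16_CONVERT branch rounds fp32 to bf16 to nearest-even instead of truncating the low 16 bits. A small host-only sketch that mirrors the bit manipulation above and shows two inputs where the results differ:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even fp32 -> bf16, mirroring the RTN path above (NaN/Inf handling included).
static uint16_t bf16_rtn(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const bool not_nan_inf = (~bits & 0x7f800000u) != 0;       // exponent not all ones
    const bool nan_payload = !not_nan_inf && (bits & 0xffffu); // keep NaN payload non-zero
    if(not_nan_inf)
        bits += 0x7fffu + ((bits >> 16) & 1u); // round to nearest, ties to even
    if(nan_payload)
        bits |= 0x10000u;
    return static_cast<uint16_t>(bits >> 16);
}

// Default path: plain truncation of the low 16 mantissa bits.
static uint16_t bf16_truncate(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

int main()
{
    // 1.005859375f = 0x3F80C000: low 16 bits above the halfway point, so RTN rounds up to
    // 0x3F81 while truncation keeps 0x3F80.
    std::printf("trunc=0x%04x rtn=0x%04x\n", bf16_truncate(1.005859375f), bf16_rtn(1.005859375f));
    // 1.01171875f = 0x3F818000: exactly halfway with an odd bf16 mantissa, so RTN rounds to
    // the even value 0x3F82 while truncation keeps 0x3F81.
    std::printf("trunc=0x%04x rtn=0x%04x\n", bf16_truncate(1.01171875f), bf16_rtn(1.01171875f));
    return 0;
}
```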