"...composable_kernel.git" did not exist on "e0d8806ca1cd8611d387f1fb441d3c9e92174ec5"
Commit 957d5dee authored by danyao12

resolve conflict

parents 10836d41 3f4eae1d
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         4,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
 static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 128)
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         1,
-        MaskingSpec, // MaskingSpecialization
-        Deterministic>;
+        MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: DataType in, AccDataType out
...
@@ -1488,10 +1488,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
         void* p_qgrad_grid,
         void* p_kgrad_grid,
         void* p_vgrad_grid,
-        const D0DataType* p_acc0_bias,
-        const D1DataType* p_acc1_bias,
-        D0DataType* p_d0grad_grid,
-        D1DataType* p_d1grad_grid,
+        const void* p_acc0_bias,
+        const void* p_acc1_bias,
+        void* p_d0grad_grid,
+        void* p_d1grad_grid,
         const std::vector<index_t>& a_gs_ms_ks_lengths,
         const std::vector<index_t>& a_gs_ms_ks_strides,
         const std::vector<index_t>& b_gs_ns_ks_lengths,
...
@@ -1344,10 +1344,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         void* p_qgrad_grid,
         void* p_kgrad_grid,
         void* p_vgrad_grid,
-        const D0DataType* p_acc0_bias,
-        const D1DataType* p_acc1_bias,
-        D0DataType* p_d0grad_grid,
-        D1DataType* p_d1grad_grid,
+        const void* p_acc0_bias,
+        const void* p_acc1_bias,
+        void* p_d0grad_grid,
+        void* p_d1grad_grid,
         const std::vector<index_t>& a_gs_ms_ks_lengths,
         const std::vector<index_t>& a_gs_ms_ks_strides,
         const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1390,8 +1390,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             static_cast<OutputDataType*>(p_vgrad_grid),
             static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
             static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
-            static_cast<const D0DataType*>(p_d0grad_grid),
-            static_cast<const D1DataType*>(p_d1grad_grid),
+            static_cast<D0DataType*>(p_d0grad_grid),
+            static_cast<D1DataType*>(p_d1grad_grid),
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b_gs_ns_ks_lengths,
...
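Aside: the two hunks above move the bias and gradient pointers in the public entry point to type-erased void*, recovering the concrete element types with static_cast when the internal Argument is built (note the "cast in struct Argument" comments); the second hunk also fixes the const-ness of the gradient casts, since d0grad/d1grad are outputs. A minimal sketch of that pattern follows; the stand-in type and the two-argument factory are illustrative, not CK's real signatures.

    // Sketch of the void*-based interface pattern above (illustrative types).
    #include <cstdint>
    #include <iostream>

    using D0DataType = std::int16_t; // stand-in for the real bias element type

    struct Argument
    {
        const D0DataType* p_acc0_bias_;
        D0DataType* p_d0grad_grid_;
    };

    // Callers pass void*, so the entry point is independent of the element
    // type; the casts happen exactly once, when the Argument is constructed.
    Argument MakeArgument(const void* p_acc0_bias, void* p_d0grad_grid)
    {
        return Argument{static_cast<const D0DataType*>(p_acc0_bias),
                        static_cast<D0DataType*>(p_d0grad_grid)};
    }

    int main()
    {
        D0DataType bias[4]   = {};
        D0DataType d0grad[4] = {};
        const Argument arg   = MakeArgument(bias, d0grad);
        std::cout << (arg.p_d0grad_grid_ == d0grad) << '\n'; // prints 1
    }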
@@ -47,8 +47,7 @@ template <typename GridwiseGemm,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -79,7 +78,6 @@ __global__ void
     const Block2CTileMap block_2_ctile_map,
     const index_t batch_count,
     const index_t h_ratio,
-    const index_t mblock,
     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
     const C0MatrixMask c0_matrix_mask,
     const uint8_t p_dropout_in_uint8_t,
@@ -124,73 +122,34 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < mblock; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                p_a_grid + a_batch_offset,
-                p_b_grid + b_batch_offset,
-                tmp_p_d0_grid,
-                p_b1_grid + b1_batch_offset,
-                p_c_grid + c_batch_offset,
-                p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
-                p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                a_grid_desc_ak0_m_ak1,
-                b_grid_desc_bk0_n_bk1,
-                d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                b1_grid_desc_bk0_n_bk1,
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                lse_grid_desc_m,
-                block_2_ctile_map,
-                c0_matrix_mask,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                z_random_matrix_offset,
-                raw_n_padded,
-                i);
-        }
-    }
-    else
-    {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-            p_a_grid + a_batch_offset,
-            p_b_grid + b_batch_offset,
-            tmp_p_d0_grid,
-            p_b1_grid + b1_batch_offset,
-            p_c_grid + c_batch_offset,
-            p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
-            p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            acc_element_op,
-            b1_element_op,
-            c_element_op,
-            a_grid_desc_ak0_m_ak1,
-            b_grid_desc_bk0_n_bk1,
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            b1_grid_desc_bk0_n_bk1,
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-            lse_grid_desc_m,
-            block_2_ctile_map,
-            c0_matrix_mask,
-            p_dropout_in_uint8_t,
-            p_dropout_rescale,
-            ph,
-            z_random_matrix_offset,
-            raw_n_padded,
-            0);
-    }
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+        p_a_grid + a_batch_offset,
+        p_b_grid + b_batch_offset,
+        tmp_p_d0_grid,
+        p_b1_grid + b1_batch_offset,
+        p_c_grid + c_batch_offset,
+        p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+        p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        acc_element_op,
+        b1_element_op,
+        c_element_op,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
+        d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        b1_grid_desc_bk0_n_bk1,
+        c_grid_desc_mblock_mperblock_nblock_nperblock,
+        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        lse_grid_desc_m,
+        block_2_ctile_map,
+        c0_matrix_mask,
+        p_dropout_in_uint8_t,
+        p_dropout_rescale,
+        ph,
+        z_random_matrix_offset,
+        raw_n_padded);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -214,7 +173,6 @@ __global__ void
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = h_ratio;
-    ignore = mblock;
    ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
     ignore = p_dropout_in_uint8_t;
@@ -299,7 +257,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
@@ -579,8 +536,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
     // Argument
     // FIXME: constness
@@ -836,9 +792,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         }
         const index_t grid_size =
-            (Deterministic ? 1
-                           : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
-            arg.batch_count_;
+            arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
         // Gemm0_K
         const auto K =
@@ -846,74 +800,72 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         float ave_time = 0;
-        auto launch_kernel =
-            [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
-                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    D0DataType,
-                    CDataType,
-                    ZDataType,
-                    LSEDataType,
-                    GemmAccDataType,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    AccElementwiseOperation,
-                    B1ElementwiseOperation,
-                    CElementwiseOperation,
-                    DeviceOp::AGridDesc_AK0_M_AK1,
-                    DeviceOp::BGridDesc_BK0_N_BK1,
-                    typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
-                    DeviceOp::B1GridDesc_BK0_N_BK1,
-                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
-                    DeviceOp::LSEGridDesc_M,
-                    typename GridwiseGemm::DefaultBlock2CTileMap,
-                    ComputeBasePtrOfStridedBatch,
-                    C0MatrixMask,
-                    has_main_k_block_loop_,
-                    is_dropout_,
-                    is_lse_storing_,
-                    Deterministic>;
-                return launch_and_time_kernel(
-                    stream_config,
-                    kernel,
-                    dim3(grid_size),
-                    dim3(BlockSize),
-                    0,
-                    arg.p_a_grid_,
-                    arg.p_b_grid_,
-                    arg.p_d0_grid_,
-                    arg.p_b1_grid_,
-                    arg.p_c_grid_,
-                    arg.p_z_grid_,
-                    arg.p_lse_grid_,
-                    arg.a_element_op_,
-                    arg.b_element_op_,
-                    arg.acc_element_op_,
-                    arg.b1_element_op_,
-                    arg.c_element_op_,
-                    arg.a_grid_desc_ak0_m_ak1_,
-                    arg.b_grid_desc_bk0_n_bk1_,
-                    arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                    arg.b1_grid_desc_bk0_n_bk1_,
-                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                    arg.lse_grid_desc_m_,
-                    arg.block_2_ctile_map_,
-                    arg.batch_count_,
-                    arg.h_ratio_,
-                    arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
-                    arg.compute_base_ptr_of_batch_,
-                    arg.c0_matrix_mask_,
-                    arg.p_dropout_in_uint8_t_,
-                    arg.p_dropout_rescale_,
-                    arg.seed_,
-                    arg.offset_,
-                    arg.m_raw_padded_,
-                    arg.n_raw_padded_);
-            };
+        auto launch_kernel = [&](auto has_main_k_block_loop_,
+                                 auto is_dropout_,
+                                 auto is_lse_storing_) {
+            const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
+                GridwiseGemm,
+                ADataType, // TODO: distiguish A/B datatype
+                D0DataType,
+                CDataType,
+                ZDataType,
+                LSEDataType,
+                GemmAccDataType,
+                AElementwiseOperation,
+                BElementwiseOperation,
+                AccElementwiseOperation,
+                B1ElementwiseOperation,
+                CElementwiseOperation,
+                DeviceOp::AGridDesc_AK0_M_AK1,
+                DeviceOp::BGridDesc_BK0_N_BK1,
+                typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+                DeviceOp::B1GridDesc_BK0_N_BK1,
+                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+                DeviceOp::LSEGridDesc_M,
+                typename GridwiseGemm::DefaultBlock2CTileMap,
+                ComputeBasePtrOfStridedBatch,
+                C0MatrixMask,
+                has_main_k_block_loop_,
+                is_dropout_,
+                is_lse_storing_>;
+            return launch_and_time_kernel(stream_config,
+                                          kernel,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          arg.p_a_grid_,
+                                          arg.p_b_grid_,
+                                          arg.p_d0_grid_,
+                                          arg.p_b1_grid_,
+                                          arg.p_c_grid_,
+                                          arg.p_z_grid_,
+                                          arg.p_lse_grid_,
+                                          arg.a_element_op_,
+                                          arg.b_element_op_,
+                                          arg.acc_element_op_,
+                                          arg.b1_element_op_,
+                                          arg.c_element_op_,
+                                          arg.a_grid_desc_ak0_m_ak1_,
+                                          arg.b_grid_desc_bk0_n_bk1_,
+                                          arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                                          arg.b1_grid_desc_bk0_n_bk1_,
+                                          arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                          arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                                          arg.lse_grid_desc_m_,
+                                          arg.block_2_ctile_map_,
+                                          arg.batch_count_,
+                                          arg.h_ratio_,
+                                          arg.compute_base_ptr_of_batch_,
+                                          arg.c0_matrix_mask_,
+                                          arg.p_dropout_in_uint8_t_,
+                                          arg.p_dropout_rescale_,
+                                          arg.seed_,
+                                          arg.offset_,
+                                          arg.m_raw_padded_,
+                                          arg.n_raw_padded_);
+        };
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
         // to concern Gemm0's loop
...
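Aside: with the Deterministic branch gone, the grid size above is again the full tile grid times the batch count, instead of one workgroup per batch looping over mblock tiles. A toy sketch of that launch-size arithmetic; the numbers and the flat tile count are illustrative, not CK's block-to-C-tile map.

    // Toy analogue of the restored grid-size computation.
    #include <iostream>

    int main()
    {
        const int M = 1024, N = 1024;               // padded C dimensions (illustrative)
        const int MPerBlock = 128, NPerBlock = 128; // tile shape (illustrative)
        const int batch_count = 8;

        // CalculateGridSize analogue: one workgroup per (M, N) output tile.
        const int tiles_per_batch = (M / MPerBlock) * (N / NPerBlock);

        std::cout << "grid size now: " << tiles_per_batch * batch_count << '\n'; // 512
        std::cout << "Deterministic grid size was: " << 1 * batch_count << '\n'; // 8
    }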
@@ -1079,7 +1079,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
             c_grid_desc_g_m_n,
             bgrad_grid_desc_g_n_k,
             b1grad_grid_desc_g_n_k,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
...
@@ -1148,7 +1148,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
             c_grid_desc_g_m_n,
             bgrad_grid_desc_g_n_k,
             b1grad_grid_desc_g_n_k,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
...
@@ -969,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             c_grid_desc_g_m_n,
             bgrad_grid_desc_g_n_k,
             b1grad_grid_desc_g_n_k,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
...
@@ -467,19 +467,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
-    //
-    // dP = dY * V^T
-    //
-    // YGrad in Gemm A position
-    static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
-                                                const std::vector<index_t>& y_gs_ms_os_strides)
-    {
-        return Transform::MakeAGridDescriptor_AK0_M_AK1(
-            Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
-            Number<Y_O1>{});
-    }
     // V in Gemm B position
     static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
                                             const std::vector<index_t>& v_gs_os_ns_strides)
@@ -1039,7 +1026,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             c_grid_desc_g_m_n,
             bgrad_grid_desc_g_n_k,
             b1grad_grid_desc_g_n_k,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
...
@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
             b1_grid_desc_g_n_k,
             c_grid_desc_g_m_n,
             z_grid_desc_g_m_n,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
...
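Aside: the recurring one-line fix above (also applied in the backward device ops earlier) changes what is passed as the LSE batch stride. lse_grid_desc_m only describes a single batch's M elements, so its element-space size is not a valid stride between batches whenever the LSE buffer's G stride is larger than M. A small sketch of the distinction, using an illustrative padded layout rather than the CK descriptor machinery:

    // Sketch: per-batch element count vs. batch (G) stride for an LSE buffer
    // laid out as (G, M). Values are illustrative.
    #include <iostream>

    int main()
    {
        const int M            = 100; // per-batch LSE length
        const int lse_g_stride = 128; // lse_gs_ms_strides[NumDimG - 1]: padded G stride

        for(int g = 0; g < 3; ++g)
        {
            // Using the per-batch span (M) as the stride mis-addresses every
            // batch after the first when the layout is padded.
            std::cout << "batch " << g << ": offset via span = " << g * M
                      << ", offset via G stride = " << g * lse_g_stride << '\n';
        }
    }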
@@ -35,8 +35,7 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           bool HasMainKBlockLoop,
           bool IsDropout,
-          bool IsLseStoring,
-          bool Deterministic>
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -88,7 +87,7 @@ __global__ void
     // per-group batch offset
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
-        (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
+        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
     const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -115,84 +114,38 @@ __global__ void
         tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
     }
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < num_blocks_per_batch; i++)
-        {
-            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-                tmp_p_d0_grid,
-                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-                arg_ptr[group_id].p_z_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-                arg_ptr[group_id].p_lse_grid_ == nullptr
-                    ? nullptr
-                    : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-                p_shared,
-                a_element_op,
-                b_element_op,
-                acc_element_op,
-                b1_element_op,
-                c_element_op,
-                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-                arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-                arg_ptr[group_id].lse_grid_desc_m_,
-                arg_ptr[group_id].block_2_ctile_map_,
-                arg_ptr[group_id].c0_matrix_mask_,
-                p_dropout_in_uint8_t,
-                p_dropout_rescale,
-                ph,
-                arg_ptr[group_id].z_random_matrix_offset_ +
-                    g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-                arg_ptr[group_id].raw_n_padded_,
-                i);
-        }
-    }
-    else
-    {
-        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
-            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
-            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
-            tmp_p_d0_grid,
-            arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
-            arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-            arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-                                                   : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-            arg_ptr[group_id].p_lse_grid_ == nullptr
-                ? nullptr
-                : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-            // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-            p_shared,
-            a_element_op,
-            b_element_op,
-            acc_element_op,
-            b1_element_op,
-            c_element_op,
-            arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
-            arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
-            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
-            arg_ptr[group_id].lse_grid_desc_m_,
-            arg_ptr[group_id].block_2_ctile_map_,
-            arg_ptr[group_id].c0_matrix_mask_,
-            p_dropout_in_uint8_t,
-            p_dropout_rescale,
-            ph,
-            arg_ptr[group_id].z_random_matrix_offset_ +
-                g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
-            arg_ptr[group_id].raw_n_padded_,
-            0);
-    }
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+        arg_ptr[group_id].p_a_grid_ + a_batch_offset,
+        arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+        tmp_p_d0_grid,
+        arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
+        arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
+                                                 : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        acc_element_op,
+        b1_element_op,
+        c_element_op,
+        arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
+        arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+        arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
+        arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+        arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+        arg_ptr[group_id].lse_grid_desc_m_,
+        arg_ptr[group_id].block_2_ctile_map_,
+        arg_ptr[group_id].c0_matrix_mask_,
+        p_dropout_in_uint8_t,
+        p_dropout_rescale,
+        ph,
+        arg_ptr[group_id].z_random_matrix_offset_ +
+            g_idx * arg_ptr[group_id].raw_m_padded_ * arg_ptr[group_id].raw_n_padded_,
+        arg_ptr[group_id].raw_n_padded_);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -282,7 +235,6 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t Acc1BiasTransferSrcScalarPerVector,
           MaskingSpecialization MaskingSpec,
-          bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
@@ -598,8 +550,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         Acc1BiasTransferSrcScalarPerVector,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec != MaskingSpecialization::MaskDisabled,
-        Deterministic>;
+        MaskingSpec != MaskingSpecialization::MaskDisabled>;
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -789,8 +740,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
         const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
         const index_t grid_size_grp =
-            (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
-            batch_count;
+            block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
         const index_t BlockEnd = grid_size_ + grid_size_grp;
         // batch stride
@@ -801,7 +751,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             b1_grid_desc_g_n_k,
             c_grid_desc_g_m_n,
             z_grid_desc_g_m_n,
-            type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+            type_convert<index_t>(problem_desc.lse_gs_ms_strides[NumDimG - 1]));
         // C0 mask
         const auto c0_matrix_mask =
@@ -967,8 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                     CElementwiseOperation,
                     has_main_k_block_loop_,
                     use_dropout_,
-                    is_lse_storing_,
-                    Deterministic>;
+                    is_lse_storing_>;
                 return launch_and_time_kernel(
                     stream_config,
...
@@ -320,18 +320,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
         if(O != K)
         {
+            std::cerr << "O = " << O << " K = " << K << std::endl;
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
         if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
+            std::cerr << "M = " << M << " O = " << O
+                      << " y_grid_desc_m_o = " << y_grid_desc_m_o.GetLength(I0) << " , "
+                      << y_grid_desc_m_o.GetLength(I1) << std::endl;
+            std::cerr << "Un-matched sizes!" << std::endl;
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
              O % Gemm1NPerBlock == 0))
         {
+            std::cerr << "M = " << M << " N = " << N << " O = " << O << std::endl;
+            std::cerr << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock
+                      << " KPerBlock = " << KPerBlock << std::endl;
+            std::cerr << "Un-aligned sizes!" << std::endl;
             return false;
         }
...
@@ -94,7 +94,6 @@ template <typename FloatAB,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
-          bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 {
@@ -531,8 +530,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         FloatGemmAcc p_dropout_rescale,
         ck::philox& ph,
         const index_t z_random_matrix_offset,
-        const index_t raw_n_padded,
-        const index_t block_idx_m)
+        const index_t raw_n_padded)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -557,7 +555,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             return;
         }
-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_m = block_work_idx[I0];
         // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
@@ -1145,11 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 0),
             tensor_operation::element_wise::PassThrough{}};
-        if constexpr(Deterministic)
-        {
-            block_sync_lds();
-        }
         do
         {
             auto n_block_data_idx_on_grid =
...