Commit d042e931 authored by danyao12

bwd CShuffleBlockTransferScalarPerVector_NPerBlock&ZDataType

parent f3e61c0a
@@ -77,6 +77,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
@@ -163,8 +166,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 1, // CShuffleNXdlPerWavePerShuffle
 S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -232,8 +235,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 2, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstance =
 // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -301,8 +304,8 @@ using DeviceGemmInstance =
 // 1, // CShuffleMXdlPerWavePerShuffle
 // 2, // CShuffleNXdlPerWavePerShuffle
 // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-// 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
-// MaskingSpec>; // MaskingSpecialization
+// CShuffleBlockTransferScalarPerVector_NPerBlock,
+// MaskingSpec>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -370,8 +373,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: S = alpha * Q * K^T
...
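The new comment in the hunk above records the rule this commit relies on: 4 scalars per vector when OutputDataType is F32, 8 when it is F16/BF16. A minimal sketch of how that rule could be expressed directly in the example instead of hand-editing the constant; the helper name is hypothetical and not part of Composable Kernel, and it assumes the OutputDataType alias defined earlier in the example file:

// Hypothetical helper: derive the vector width from the output element size
// (4-byte F32 -> 4, 2-byte F16/BF16 -> 8), mirroring the comment in the diff.
template <typename OutDataType>
static constexpr ck::index_t ScalarPerVectorFor()
{
    return sizeof(OutDataType) == 4 ? 4 : 8;
}

static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
    ScalarPerVectorFor<OutputDataType>();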
@@ -76,6 +76,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
@@ -162,8 +165,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 1, // CShuffleNXdlPerWavePerShuffle
 S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -231,8 +234,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 2, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstance =
 // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -300,8 +303,8 @@ using DeviceGemmInstance =
 // 1, // CShuffleMXdlPerWavePerShuffle
 // 2, // CShuffleNXdlPerWavePerShuffle
 // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-// 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
-// MaskingSpec>; // MaskingSpecialization
+// CShuffleBlockTransferScalarPerVector_NPerBlock,
+// MaskingSpec>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -369,8 +372,8 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: S = alpha * Q * K^T
...
@@ -601,6 +601,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
...
@@ -600,6 +600,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
...
@@ -95,7 +95,7 @@ __global__ void
 const index_t global_thread_id = get_thread_global_1d_id();
 ck::philox ph(seed, global_thread_id, offset);
-unsigned short* z_matrix_ptr =
+auto z_matrix_ptr =
 (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
@@ -537,6 +537,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
...
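A note on the `auto z_matrix_ptr` change in the kernel entry point above: once `p_z_grid_` carries the templated Z element type (which the ZDataType plumbing in this commit suggests), the conditional expression already has type `ZDataType*`, since the `nullptr` branch converts to the pointer type, so `auto` follows whatever element type the instantiation uses instead of pinning the local to `unsigned short`. A self-contained sketch of that deduction, using simplified stand-in names rather than the actual kernel code:

#include <type_traits>

// Toy stand-in for the kernel-side pointer setup (not the CK kernel itself).
template <typename ZDataType>
ZDataType* select_z_ptr(ZDataType* p_z_grid, long z_batch_offset)
{
    // nullptr converts to ZDataType*, so the whole conditional (and `auto`)
    // deduces ZDataType*.
    auto z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
    static_assert(std::is_same_v<decltype(z_matrix_ptr), ZDataType*>);
    return z_matrix_ptr;
}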
@@ -95,7 +95,7 @@ __global__ void
 const index_t global_thread_id = get_thread_global_1d_id();
 ck::philox ph(seed, global_thread_id, offset);
-unsigned short* z_matrix_ptr =
+auto z_matrix_ptr =
 (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
@@ -530,6 +530,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
...
@@ -22,6 +22,7 @@ namespace ck {
 template <typename InputDataType,
 typename OutputDataType,
+typename ZDataType,
 typename GemmDataType,
 typename FloatGemmAcc,
 typename FloatCShuffle,
@@ -1237,7 +1238,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 typename YGradGridDesc_O0_M_O1>
 __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
 const InputDataType* __restrict__ p_k_grid,
-unsigned short* __restrict__ p_z_grid,
+ZDataType* __restrict__ p_z_grid,
 const InputDataType* __restrict__ p_v_grid,
 const InputDataType* __restrict__ p_y_grid,
 const FloatLSE* __restrict__ p_lse_grid,
@@ -1553,7 +1554,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
 ushort,
-ushort,
+ZDataType,
 decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 tensor_operation::element_wise::PassThrough,
...
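For the `ThreadwiseTensorSliceTransfer_v1r3` change just above: only the second type argument moves to ZDataType, which reads as the destination (global Z grid) element type, while the first stays `ushort` for the values held in registers; on that reading the copy converts element-wise as it stores. A toy analogue of that source/destination split, as an illustration of the pattern rather than the CK transfer class itself:

// Toy per-thread copy with distinct source and destination element types.
template <typename SrcData, typename DstData>
struct ThreadwiseCopySketch
{
    static void Run(const SrcData* src, DstData* dst, int n)
    {
        for(int i = 0; i < n; ++i)
            dst[i] = static_cast<DstData>(src[i]); // convert as each element is stored
    }
};
// In the kernel above, SrcData would stay ushort (register-side random values)
// and DstData would follow ZDataType (the Z grid's element type).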
@@ -22,6 +22,7 @@ namespace ck {
 template <typename InputDataType,
 typename OutputDataType,
+typename ZDataType,
 typename GemmDataType,
 typename FloatGemmAcc,
 typename FloatCShuffle,
@@ -1147,7 +1148,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 typename YGradGridDesc_M0_O_M1>
 __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
 const InputDataType* __restrict__ p_k_grid,
-unsigned short* __restrict__ p_z_grid,
+ZDataType* __restrict__ p_z_grid,
 const InputDataType* __restrict__ p_v_grid,
 const InputDataType* __restrict__ p_y_grid,
 const FloatLSE* __restrict__ p_lse_grid,
@@ -1485,7 +1486,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
 ushort,
-ushort,
+ZDataType,
 decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 tensor_operation::element_wise::PassThrough,
...
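Taken together, the gridwise changes turn the Z (dropout-mask) buffer's element type into a template parameter instead of a hard-coded `unsigned short`. A toy illustration of the pattern (not CK code); instantiating it with ZDataType = unsigned short reproduces the previous behavior:

// Toy kernel facade showing the Z element type threaded through as a parameter.
template <typename InputDataType, typename OutputDataType, typename ZDataType>
struct AttentionBackwardSketch
{
    static void Run(const InputDataType* p_q_grid,
                    const InputDataType* p_k_grid,
                    ZDataType* p_z_grid, // was fixed to unsigned short* before this commit
                    OutputDataType* p_qgrad_grid)
    {
        // Body elided; the point is the signature: every buffer's element type,
        // including the Z grid's, now comes from a template parameter.
        (void)p_q_grid; (void)p_k_grid; (void)p_z_grid; (void)p_qgrad_grid;
    }
};

// e.g. AttentionBackwardSketch<float, float, unsigned short>::Run(q, k, z, dq);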