Commit d042e931 authored by danyao12

bwd: parameterize CShuffleBlockTransferScalarPerVector_NPerBlock and pass ZDataType through the backward attention kernels

parent f3e61c0a
@@ -77,6 +77,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
@@ -163,7 +166,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 1, // CShuffleNXdlPerWavePerShuffle
 S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
@@ -232,7 +235,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 2, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstance =
@@ -301,8 +304,8 @@ using DeviceGemmInstance =
 // 1, // CShuffleMXdlPerWavePerShuffle
 // 2, // CShuffleNXdlPerWavePerShuffle
 // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-// 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
-// MaskingSpec>; // MaskingSpecialization
+// CShuffleBlockTransferScalarPerVector_NPerBlock,
+// MaskingSpec>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -370,7 +373,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #endif
......
@@ -76,6 +76,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
@@ -162,7 +165,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 1, // CShuffleNXdlPerWavePerShuffle
 S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstance =
@@ -231,7 +234,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 2, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstance =
@@ -300,8 +303,8 @@ using DeviceGemmInstance =
 // 1, // CShuffleMXdlPerWavePerShuffle
 // 2, // CShuffleNXdlPerWavePerShuffle
 // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-// 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
-// MaskingSpec>; // MaskingSpecialization
+// CShuffleBlockTransferScalarPerVector_NPerBlock,
+// MaskingSpec>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
 ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -369,7 +372,7 @@ using DeviceGemmInstance =
 1, // CShuffleMXdlPerWavePerShuffle
 4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-4, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #endif
......
@@ -601,6 +601,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
......
@@ -600,6 +600,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
......
@@ -95,7 +95,7 @@ __global__ void
 const index_t global_thread_id = get_thread_global_1d_id();
 ck::philox ph(seed, global_thread_id, offset);
-unsigned short* z_matrix_ptr =
+auto z_matrix_ptr =
 (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
@@ -537,6 +537,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
......
@@ -95,7 +95,7 @@ __global__ void
 const index_t global_thread_id = get_thread_global_1d_id();
 ck::philox ph(seed, global_thread_id, offset);
-unsigned short* z_matrix_ptr =
+auto z_matrix_ptr =
 (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
 : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
@@ -530,6 +530,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 InputDataType, // TODO: distinguish A/B datatype
 OutputDataType,
+ZDataType,
 GemmDataType,
 GemmAccDataType,
 CShuffleDataType,
......
@@ -22,6 +22,7 @@ namespace ck {
 template <typename InputDataType,
 typename OutputDataType,
+typename ZDataType,
 typename GemmDataType,
 typename FloatGemmAcc,
 typename FloatCShuffle,
@@ -1237,7 +1238,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 typename YGradGridDesc_O0_M_O1>
 __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
 const InputDataType* __restrict__ p_k_grid,
-unsigned short* __restrict__ p_z_grid,
+ZDataType* __restrict__ p_z_grid,
 const InputDataType* __restrict__ p_v_grid,
 const InputDataType* __restrict__ p_y_grid,
 const FloatLSE* __restrict__ p_lse_grid,
@@ -1553,7 +1554,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
 ushort,
-ushort,
+ZDataType,
 decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 tensor_operation::element_wise::PassThrough,
......
@@ -22,6 +22,7 @@ namespace ck {
 template <typename InputDataType,
 typename OutputDataType,
+typename ZDataType,
 typename GemmDataType,
 typename FloatGemmAcc,
 typename FloatCShuffle,
@@ -1147,7 +1148,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 typename YGradGridDesc_M0_O_M1>
 __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
 const InputDataType* __restrict__ p_k_grid,
-unsigned short* __restrict__ p_z_grid,
+ZDataType* __restrict__ p_z_grid,
 const InputDataType* __restrict__ p_v_grid,
 const InputDataType* __restrict__ p_y_grid,
 const FloatLSE* __restrict__ p_lse_grid,
@@ -1485,7 +1486,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
 ushort,
-ushort,
+ZDataType,
 decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
 tensor_operation::element_wise::PassThrough,
......
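Note on the new constant: the comments added in both example files state the rule (a CShuffle store vector width of 4 when OutputDataType is F32, 8 when it is F16/BF16), but the value itself is still edited by hand. A minimal sketch of expressing that rule at compile time is shown below; it assumes the example file already defines OutputDataType as one of float, ck::half_t, or ck::bhalf_t, uses only std::is_same_v from <type_traits>, and is not part of this commit.

#include <type_traits>

// Sketch only (not part of this commit): derive the CShuffle store vector width
// from OutputDataType instead of editing the constant by hand. Assumes
// OutputDataType is float, ck::half_t, or ck::bhalf_t, as in these examples.
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
    std::is_same_v<OutputDataType, float> ? 4 : 8;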