Commit be38f68d authored by ltqin

add padding code for M

parent a188073b
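For context: this change pads the M (sequence length) dimension used by the Y·dY ("D") reduction in the attention backward pass, rounding it up to a multiple of DMPerBlock (defined as BlockSize) so the kernel never sees a partial M tile. Below is a minimal standalone sketch of that rounding arithmetic; it mirrors the new MakeDGridDescriptor_M in the diff but is plain C++ rather than CK's descriptor machinery, and the DMPerBlock and MRaw values are illustrative assumptions.

// Minimal sketch (not CK code): the M-padding arithmetic behind the new
// MakeDGridDescriptor_M. DMPerBlock = 256 and MRaw = 1000 are assumptions.
#include <cstdio>

int main()
{
    const int DMPerBlock = 256;  // assumed BlockSize
    const int MRaw       = 1000; // assumed raw sequence length

    // CK equivalent: math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock
    const int M    = (MRaw + DMPerBlock - 1) / DMPerBlock * DMPerBlock;
    const int MPad = M - MRaw; // tail later exposed via make_right_pad_transform

    std::printf("MRaw=%d M=%d MPad=%d\n", MRaw, M, MPad); // MRaw=1000 M=1024 MPad=24
    return 0;
}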
@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false;
// If 32 < DIM <= 64, use prototype1 2nd template.
// If 64 < DIM <= 128, use prototype2 2nd template.
#if(DIM <= 32)
-using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
+using DeviceGemmInstance = ck::tensor_operation::device::
+    DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -172,8 +172,8 @@ using DeviceGemmInstance =
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
-using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
+using DeviceGemmInstance = ck::tensor_operation::device::
+    DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
-#define DIM 32 // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
@@ -337,6 +337,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
@@ -371,6 +372,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec,
CSpec>;
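+// Like Transform above, but tiled with DMPerBlock in M so that MNKOPadding
+// pads M to the Y.dY kernel's block size rather than the GEMM's MPerBlock.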
+using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
+    Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+    Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+    GemmSpecialization::MNKOPadding,
+    ASpec,
+    BSpec,
+    B1Spec,
+    CSpec>;
/*
Descriptors for inputs:
@@ -596,6 +606,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
}
+static auto MakeDGridDescriptor_M(index_t MRaw)
+{
+    const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
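+    // Round MRaw up to a multiple of DMPerBlock and expose the tail
+    // [MRaw, M) via a right-pad transform, so no partial M tile remains.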
+    const auto M    = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
+    const auto MPad = M - MRaw;
+    return transform_tensor_descriptor(d_grid_desc_mraw,
+                                       make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                       make_tuple(Sequence<0>{}),
+                                       make_tuple(Sequence<0>{}));
+}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -606,7 +629,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
+using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
+using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
@@ -705,7 +729,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-DGridDesc_M,
+LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -754,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
-BlockSize,
+DMPerBlock,
DKPerBlock>;
// Argument
struct Argument : public BaseArgument
@@ -818,8 +842,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
+d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
+                                                       c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
-d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
+d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_o0_m_o1_{DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
@@ -836,7 +862,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
-d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
+d_block_2_ctile_map_{
+    GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
@@ -881,7 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-y_grid_desc_m_o_);
+d_y_grid_desc_m_o_);
// Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
@@ -932,6 +959,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
+DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
@@ -998,15 +1026,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const index_t grid_size =
(Deterministic
? 1
-: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+: arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_)) *
arg.batch_count_;
std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << ","
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_y_grid_desc_m_o_: {" << arg.d_y_grid_desc_m_o_.GetLength(I0)
<< "," << arg.d_y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_: {" << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v1<
@@ -1062,7 +1092,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M,
-DeviceOp::DGridDesc_M,
+DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
@@ -1096,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_,
-arg.d_grid_desc_m_,
+arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_,
arg.batch_count_,
@@ -1138,7 +1168,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
// TODO: Check if tensor specialization & strides mismatch
-if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
+if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
......
@@ -342,6 +342,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
@@ -377,6 +378,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec,
CSpec>;
+using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
+    Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+    Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+    GemmSpecialization::MNKOPadding,
+    ASpec,
+    BSpec,
+    B1Spec,
+    CSpec>;
/*
Descriptors for inputs:
@@ -602,6 +612,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
}
+static auto MakeDGridDescriptor_M(index_t MRaw)
+{
+    const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+    const auto M    = math::integer_divide_ceil(MRaw, BlockSize) * BlockSize;
+    const auto MPad = M - MRaw;
+    return transform_tensor_descriptor(d_grid_desc_mraw,
+                                       make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                       make_tuple(Sequence<0>{}),
+                                       make_tuple(Sequence<0>{}));
+}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
@@ -612,7 +635,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
+using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
+using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
@@ -711,7 +735,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-DGridDesc_M,
+LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -768,7 +792,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
-BlockSize,
+DMPerBlock,
DKPerBlock>;
// Argument
struct Argument : public BaseArgument
@@ -832,8 +856,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
+d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
+                                                       c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
-d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
+d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
@@ -849,7 +875,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
-d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
+d_block_2_ctile_map_{
+    GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
@@ -894,7 +921,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-y_grid_desc_m_o_);
+d_y_grid_desc_m_o_);
// Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
@@ -945,6 +972,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
+DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
@@ -1011,15 +1039,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const index_t grid_size =
(Deterministic
? 1
-: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+: arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_)) *
arg.batch_count_;
std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << ","
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_y_grid_desc_m_o_: {" << arg.d_y_grid_desc_m_o_.GetLength(I0)
<< "," << arg.d_y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_: {" << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v2<
@@ -1079,7 +1109,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M,
-DeviceOp::DGridDesc_M,
+DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
@@ -1113,7 +1143,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_,
-arg.d_grid_desc_m_,
+arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_,
arg.batch_count_,
@@ -1165,7 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
// TODO: Check if tensor specialization & strides mismatch
-if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
+if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
......
@@ -315,6 +315,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
@@ -378,6 +379,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec,
CSpec>;
+using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
+    Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+    Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+    GemmSpecialization::MNKOPadding,
+    ASpec,
+    BSpec,
+    B1Spec,
+    CSpec>;
/*
Descriptors for inputs:
@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
}
+static auto MakeDGridDescriptor_M(index_t MRaw)
+{
+    const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+    const auto M    = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
+    const auto MPad = M - MRaw;
+    return transform_tensor_descriptor(d_grid_desc_mraw,
+                                       make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                       make_tuple(Sequence<0>{}),
+                                       make_tuple(Sequence<0>{}));
+}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
-using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
+using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
+using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
constexpr static auto make_MaskOutPredicate()
{
@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-DGridDesc_M,
+LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -707,7 +731,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
-BlockSize,
+DMPerBlock,
DKPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
@@ -752,6 +776,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter
DDataType* p_d_grid_;
+DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -931,15 +956,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m =
-    DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+    DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
+    problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
index_t d_block_start = d_grid_size_;
-const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
+const auto d_block_2_ctile_map = DBlock2CTileMap(d_y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-    y_grid_desc_m_o);
+    d_y_grid_desc_m_o);
index_t d_num_blocks_per_batch =
-    d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
+    d_block_2_ctile_map.CalculateGridSize(d_y_grid_desc_m_o);
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end;
@@ -973,6 +1001,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded,
raw_n_padded,
p_d_grid,
+d_y_grid_desc_m_o,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -1151,7 +1180,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
-if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
+if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
......
@@ -322,6 +322,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
@@ -385,6 +386,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec,
CSpec>;
+using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
+    Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
+    Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
+    GemmSpecialization::MNKOPadding,
+    ASpec,
+    BSpec,
+    B1Spec,
+    CSpec>;
/*
Descriptors for inputs:
@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
}
+static auto MakeDGridDescriptor_M(index_t MRaw)
+{
+    const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+    const auto M    = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
+    const auto MPad = M - MRaw;
+    return transform_tensor_descriptor(d_grid_desc_mraw,
+                                       make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                       make_tuple(Sequence<0>{}),
+                                       make_tuple(Sequence<0>{}));
+}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
-using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
+using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
+using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
constexpr static auto make_MaskOutPredicate()
{
@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-DGridDesc_M,
+LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -715,7 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
-BlockSize,
+DMPerBlock,
DKPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
@@ -760,6 +784,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter
DDataType* p_d_grid_;
+DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -934,16 +959,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m =
-    DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+    DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
+    problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
index_t d_block_start = d_grid_size_;
-const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
+const auto d_block_2_ctile_map = DBlock2CTileMap(d_y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-    y_grid_desc_m_o);
+    d_y_grid_desc_m_o);
index_t d_num_blocks_per_batch =
-    d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
+    d_block_2_ctile_map.CalculateGridSize(d_y_grid_desc_m_o);
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end;
@@ -977,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded,
raw_n_padded,
p_d_grid,
+d_y_grid_desc_m_o,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -1153,7 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
-if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_,
+if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
......
@@ -45,21 +45,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
return false;
}
-// const auto M = y_grid_desc_m_n.GetLength(I0);
+const auto M = y_grid_desc_m_n.GetLength(I0);
const auto N = y_grid_desc_m_n.GetLength(I1);
if(N < NPerBlock)
{
return false;
}
-// std::cout << "m: " << M <<" n: " << N << std::endl;
-// if(M < MPerBlock)
-// {
-//     return false;
-// }
-// if(M % MPerBlock != 0)
-// {
-//     return false;
-// }
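+// With the Y.dY descriptor now built via MNKOPadding, M arrives padded to
+// a multiple of MPerBlock, so these checks can be enforced.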
+if(M < MPerBlock)
+{
+    return false;
+}
+if(M % MPerBlock != 0)
+{
+    return false;
+}
return true;
}
......