Commit be38f68d authored by ltqin's avatar ltqin
Browse files

add padding code for M

parent a188073b
...@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false; ...@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false;
// If 32 < DIM <= 64, use prototype1 2nd template. // If 32 < DIM <= 64, use prototype1 2nd template.
// If 64 < DIM <= 128, use prototype2 2nd template. // If 64 < DIM <= 128, use prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance = ck::tensor_operation::device::
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1< DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -172,8 +172,8 @@ using DeviceGemmInstance = ...@@ -172,8 +172,8 @@ using DeviceGemmInstance =
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance = ck::tensor_operation::device::
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1< DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 32 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -337,6 +337,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -337,6 +337,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// M-dimension tile size for the D-related descriptors: it is the M entry of DTransform's
// tile sequence and the padding granularity in MakeDGridDescriptor_M.
// Currently the same value as BlockSize.
static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
...@@ -371,6 +372,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -371,6 +372,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec, B1Spec,
CSpec>; CSpec>;
// Contraction-to-GEMM transform instantiated with the tile sequence
// <DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock> and GemmSpecialization::MNKOPadding.
// DTransform::MakeCGridDescriptor_M_N is what builds d_y_grid_desc_m_o_ (type DYGridDesc_M_O).
using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::MNKOPadding,
ASpec,
BSpec,
B1Spec,
CSpec>;
/* /*
Descriptors for inputs: Descriptors for inputs:
...@@ -596,6 +606,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -596,6 +606,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
} }
} }
// Builds the 1-D descriptor for D.
// MRaw is the unpadded length. The returned descriptor is the packed descriptor of that
// length with a right-pad transform applied on dimension 0, so that the padded length is
// MRaw rounded up to the next multiple of DMPerBlock.
static auto MakeDGridDescriptor_M(index_t MRaw)
{
    // Round MRaw up to a whole number of DMPerBlock tiles, then derive the pad amount.
    const auto padded_length = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
    const auto right_pad     = padded_length - MRaw;

    const auto unpadded_desc = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

    return transform_tensor_descriptor(
        unpadded_desc,
        make_tuple(make_right_pad_transform(MRaw, right_pad)),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
...@@ -606,7 +629,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -606,7 +629,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
...@@ -705,7 +729,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -705,7 +729,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
DGridDesc_M, LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -754,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -754,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, DMPerBlock,
DKPerBlock>; DKPerBlock>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -818,8 +842,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -818,8 +842,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])}, d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{ k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_o0_m_o1_{DeviceOp::MakeYGradGridDescriptor_O0_M_O1( ygrad_grid_desc_o0_m_o1_{DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
...@@ -836,7 +862,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -836,7 +862,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)}, d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{}, d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
...@@ -881,7 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -881,7 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ = d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_); d_y_grid_desc_m_o_);
// Print(); // Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]); m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
...@@ -932,6 +959,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -932,6 +959,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -998,15 +1026,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -998,15 +1026,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const index_t grid_size = const index_t grid_size =
(Deterministic (Deterministic
? 1 ? 1
: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) * : arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_)) *
arg.batch_count_; arg.batch_count_;
std::cout << "grid_size: " << grid_size std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_ << "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl; << " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl; << std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << "," std::cout << "arg.d_y_grid_desc_m_o_: {" << arg.d_y_grid_desc_m_o_.GetLength(I0)
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl; << "," << arg.d_y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_: {" << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
auto launch_kernel = [&]() { auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v1< const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v1<
...@@ -1062,7 +1092,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1062,7 +1092,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M, DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1, DeviceOp::YGradGridDesc_O0_M_O1,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
...@@ -1096,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1096,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.d_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_, arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
...@@ -1138,7 +1168,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1138,7 +1168,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
} }
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_)) if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{ {
return false; return false;
} }
......
...@@ -342,6 +342,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -342,6 +342,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// M-dimension tile size for the D-related descriptors; it is the M entry of DTransform's
// tile sequence. Currently the same value as BlockSize.
static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
...@@ -377,6 +378,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -377,6 +378,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec, B1Spec,
CSpec>; CSpec>;
// Contraction-to-GEMM transform instantiated with the tile sequence
// <DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock> and GemmSpecialization::MNKOPadding.
// DTransform::MakeCGridDescriptor_M_N is what builds d_y_grid_desc_m_o_ (type DYGridDesc_M_O).
using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::MNKOPadding,
ASpec,
BSpec,
B1Spec,
CSpec>;
/* /*
Descriptors for inputs: Descriptors for inputs:
...@@ -602,6 +612,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -602,6 +612,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
} }
} }
// Builds the 1-D descriptor for D.
// MRaw is the unpadded length. The returned descriptor is the packed descriptor of that
// length with a right-pad transform applied on dimension 0, so that the padded length is
// MRaw rounded up to the next multiple of DMPerBlock.
static auto MakeDGridDescriptor_M(index_t MRaw)
{
    const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

    // Pad to DMPerBlock (not BlockSize directly) so the D padding always follows the same
    // tile size as DTransform, even if DMPerBlock is later decoupled from BlockSize.
    const auto M    = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
    const auto MPad = M - MRaw;

    return transform_tensor_descriptor(d_grid_desc_mraw,
                                       make_tuple(make_right_pad_transform(MRaw, MPad)),
                                       make_tuple(Sequence<0>{}),
                                       make_tuple(Sequence<0>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
...@@ -612,7 +635,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -612,7 +635,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
...@@ -711,7 +735,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -711,7 +735,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
DGridDesc_M, LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -768,7 +792,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -768,7 +792,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, DMPerBlock,
DKPerBlock>; DKPerBlock>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -832,8 +856,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -832,8 +856,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])}, d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{ k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)}, ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
...@@ -849,7 +875,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -849,7 +875,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)}, d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{}, d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
...@@ -894,7 +921,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -894,7 +921,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ = d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_); d_y_grid_desc_m_o_);
// Print(); // Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]); m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
...@@ -945,6 +972,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -945,6 +972,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -1011,15 +1039,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1011,15 +1039,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const index_t grid_size = const index_t grid_size =
(Deterministic (Deterministic
? 1 ? 1
: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) * : arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_)) *
arg.batch_count_; arg.batch_count_;
std::cout << "grid_size: " << grid_size std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_ << "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl; << " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl; << std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << "," std::cout << "arg.d_y_grid_desc_m_o_: {" << arg.d_y_grid_desc_m_o_.GetLength(I0)
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl; << "," << arg.d_y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_: {" << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
auto launch_kernel = [&]() { auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v2< const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v2<
...@@ -1079,7 +1109,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1079,7 +1109,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M, DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1, DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
...@@ -1113,7 +1143,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1113,7 +1143,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.d_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_, arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
...@@ -1165,7 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1165,7 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
} }
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_)) if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{ {
return false; return false;
} }
......
...@@ -315,6 +315,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -315,6 +315,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// M-dimension tile size for the D-related descriptors: it is the M entry of DTransform's
// tile sequence and the padding granularity in MakeDGridDescriptor_M.
// Currently the same value as BlockSize.
static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
...@@ -378,6 +379,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -378,6 +379,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec, B1Spec,
CSpec>; CSpec>;
// Contraction-to-GEMM transform instantiated with the tile sequence
// <DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock> and GemmSpecialization::MNKOPadding.
// DTransform::MakeCGridDescriptor_M_N is what builds each group's d_y_grid_desc_m_o
// (type DYGridDesc_M_O).
using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::MNKOPadding,
ASpec,
BSpec,
B1Spec,
CSpec>;
/* /*
Descriptors for inputs: Descriptors for inputs:
...@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
} }
} }
// Builds the 1-D descriptor for D.
// MRaw is the unpadded length. The returned descriptor is the packed descriptor of that
// length with a right-pad transform applied on dimension 0, so that the padded length is
// MRaw rounded up to the next multiple of DMPerBlock.
static auto MakeDGridDescriptor_M(index_t MRaw)
{
    // Round MRaw up to a whole number of DMPerBlock tiles, then derive the pad amount.
    const auto padded_length = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
    const auto right_pad     = padded_length - MRaw;

    const auto unpadded_desc = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

    return transform_tensor_descriptor(
        unpadded_desc,
        make_tuple(make_right_pad_transform(MRaw, right_pad)),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
...@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
DGridDesc_M, LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -707,7 +731,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -707,7 +731,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, DMPerBlock,
DKPerBlock>; DKPerBlock>;
using DBlock2CTileMap = using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>; OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...@@ -752,6 +776,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -752,6 +776,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -931,15 +956,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -931,15 +956,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters // D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]); const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
index_t d_block_start = d_grid_size_; index_t d_block_start = d_grid_size_;
const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start); const auto d_block_2_ctile_map = DBlock2CTileMap(d_y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); d_y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o); d_block_2_ctile_map.CalculateGridSize(d_y_grid_desc_m_o);
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count; index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end; d_grid_size_ = d_block_end;
...@@ -973,6 +1001,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -973,6 +1001,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded, raw_m_padded,
raw_n_padded, raw_n_padded,
p_d_grid, p_d_grid,
d_y_grid_desc_m_o,
d_grid_desc_m, d_grid_desc_m,
d_block_2_ctile_map, d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock, d_y_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -1151,7 +1180,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -1151,7 +1180,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i]; const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_, if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_)) kernel_arg.d_block_2_ctile_map_))
{ {
return false; return false;
......
...@@ -322,6 +322,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -322,6 +322,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// M-dimension tile size for the D tensor path. It is the M entry of DTransform's
// per-block tile Sequence and the granularity to which MakeDGridDescriptor_M
// right-pads the raw M length. Currently tied to BlockSize.
static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
...@@ -385,6 +386,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -385,6 +386,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec, B1Spec,
CSpec>; CSpec>;
// Contraction-to-GEMM transform used only for the D tensor path.
// The per-block tile Sequence uses DMPerBlock as its M entry, and the GEMM
// specialization is fixed to MNKOPadding
// (NOTE(review): presumably so that M is padded up to a multiple of DMPerBlock
// regardless of the main GEMM's specialization -- confirm against the transform).
// DTransform::MakeCGridDescriptor_M_N provides the DYGridDesc_M_O type and builds
// d_y_grid_desc_m_o from the problem's C lengths/strides.
using DTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<DMPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::MNKOPadding,
ASpec,
BSpec,
B1Spec,
CSpec>;
/* /*
Descriptors for inputs: Descriptors for inputs:
...@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
} }
} }
// Creates the 1-D descriptor for the D tensor.
//
// MRaw is the unpadded length along M. The packed length-MRaw descriptor is
// right-padded so that its total length becomes the smallest multiple of
// DMPerBlock that is >= MRaw.
static auto MakeDGridDescriptor_M(index_t MRaw)
{
    // Number of DMPerBlock-sized tiles needed to cover MRaw, and the number of
    // trailing pad elements that rounds MRaw up to that many whole tiles.
    const auto num_m_tiles = math::integer_divide_ceil(MRaw, DMPerBlock);
    const auto m_pad       = num_m_tiles * DMPerBlock - MRaw;

    return transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(MRaw)),
                                       make_tuple(make_right_pad_transform(MRaw, m_pad)),
                                       make_tuple(Sequence<0>{}),
                                       make_tuple(Sequence<0>{}));
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
...@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
DGridDesc_M, LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -715,7 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -715,7 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, DMPerBlock,
DKPerBlock>; DKPerBlock>;
using DBlock2CTileMap = using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>; OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...@@ -760,6 +784,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -760,6 +784,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -934,16 +959,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -934,16 +959,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters // D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]); const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
index_t d_block_start = d_grid_size_; index_t d_block_start = d_grid_size_;
const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start); const auto d_block_2_ctile_map = DBlock2CTileMap(d_y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); d_y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o); d_block_2_ctile_map.CalculateGridSize(d_y_grid_desc_m_o);
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count; index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end; d_grid_size_ = d_block_end;
...@@ -977,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -977,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded, raw_m_padded,
raw_n_padded, raw_n_padded,
p_d_grid, p_d_grid,
d_y_grid_desc_m_o,
d_grid_desc_m, d_grid_desc_m,
d_block_2_ctile_map, d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock, d_y_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -1153,7 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1153,7 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i]; const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.y_grid_desc_m_o_, if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_)) kernel_arg.d_block_2_ctile_map_))
{ {
return false; return false;
......
...@@ -45,21 +45,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -45,21 +45,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{ {
return false; return false;
} }
// const auto M = y_grid_desc_m_n.GetLength(I0); const auto M = y_grid_desc_m_n.GetLength(I0);
const auto N = y_grid_desc_m_n.GetLength(I1); const auto N = y_grid_desc_m_n.GetLength(I1);
if(N < NPerBlock) if(N < NPerBlock)
{ {
return false; return false;
} }
// std::cout << "m: " << M <<" n: " << N << std::endl;
// if(M < MPerBlock) if(M < MPerBlock)
// { {
// return false; return false;
// } }
// if(M % MPerBlock != 0) if(M % MPerBlock != 0)
// { {
// return false; return false;
// } }
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment