Commit 33fad9ba authored by ltqin's avatar ltqin
Browse files

regular code

parent 513abed6
......@@ -776,7 +776,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ORSDataType,
YGridDesc_M_O,
ORSGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -790,14 +789,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockLdsExtraM,
BBlockLdsExtraN,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
LoopSched,
Deterministic>;
// Argument
struct Argument : public BaseArgument
......
......@@ -792,7 +792,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ORSDataType,
YGridDesc_M_O,
ORSGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -806,14 +805,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockLdsExtraM,
BBlockLdsExtraN,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
LoopSched,
Deterministic>;
// Argument
struct Argument : public BaseArgument
......
......@@ -26,7 +26,6 @@ template <typename InputDataType,
typename FloatORS,
typename CGridDesc_M_N,
typename ORSGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
......@@ -40,21 +39,11 @@ template <typename InputDataType,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
index_t ABlockLdsExtraM,
index_t BBlockLdsExtraN,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
LoopScheduler LoopSched,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
bool Deterministic>
struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -74,18 +63,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
......@@ -102,57 +85,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
const auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
const auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n,
......@@ -171,13 +103,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
......@@ -335,7 +260,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto ors_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ors_grid, ors_grid_desc_m.GetElementSpaceSize());
ignore = ors_grid_buf;
// divide block work by [M, O]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -362,13 +286,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
constexpr auto thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
// constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
// constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
// constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
// constexpr auto n3 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
// constexpr auto n4 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto ors_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
// m0, n0 are m/n repeat per wave
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment