Commit cf9ef868 authored by ltqin

remove useless code

parent 118742b6
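Context for the change, as a minimal sketch (illustrative only, not the Composable Kernel sources; `TileWork` and `run_tile` are hypothetical stand-ins for the gridwise `Run` call and its arguments): the YDotYGrad kernel previously carried a compile-time `Deterministic` switch plus an `nblock` launch argument, so each workgroup either looped over all row tiles in a fixed order or ran a single tile; this commit drops the switch so the kernel always takes the single direct path.

```cpp
// Sketch of the pattern being removed, under the assumptions stated above.
#include <cstdint>

using index_t = std::int32_t;

struct TileWork // hypothetical per-tile payload
{
    const float* y;
    const float* ygrad;
    float*       out;
};

__device__ inline void run_tile(const TileWork&, index_t /*tile_idx*/)
{
    // per-tile y . ygrad reduction would go here
}

// Before this commit: compile-time switch plus an nblock argument.
template <bool Deterministic>
__global__ void y_dot_ygrad_before(TileWork work, index_t nblock)
{
    if constexpr(Deterministic)
    {
        // deterministic path: every workgroup walks all tiles in order
        for(index_t i = 0; i < nblock; ++i)
            run_tile(work, i);
    }
    else
    {
        run_tile(work, 0);
    }
}

// After this commit: the switch and nblock are gone; each workgroup handles
// only the tile assigned to it by the block-to-tile map.
__global__ void y_dot_ygrad_after(TileWork work)
{
    run_tile(work, 0);
}
```

Dropping the branch also lets the `Deterministic` template parameter and the `nblock` grid-size computation disappear from the device-op launch code, as the hunks below show.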
@@ -34,8 +34,7 @@ template <typename GridwiseGemm,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename ORSGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
bool Deterministic>
typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -49,7 +48,6 @@ __global__ void
const ORSGridDescriptor_M ors_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -65,32 +63,15 @@ __global__ void
const long_index_t ors_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map,
i);
}
}
else
{
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map,
0);
}
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map);
#else
ignore = p_y_grid;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -771,27 +752,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
GemmDataType,
GemmAccDataType,
ORSDataType,
YGridDesc_M_O,
ORSGridDesc_M,
BlockSize,
128,
128,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
32,
32,
1,
4,
ABlockLdsExtraM,
BBlockLdsExtraN,
Deterministic>;
32>;
// Argument
struct Argument : public BaseArgument
{
@@ -1061,8 +1027,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::ORSGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
Deterministic>;
ComputeBasePtrOfStridedBatch>;
return launch_and_time_kernel(
stream_config,
@@ -1077,7 +1042,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.ors_grid_desc_m_,
arg.ors_block_2_ctile_map_,
arg.batch_count_,
arg.ors_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
arg.compute_base_ptr_of_batch_);
};
......
@@ -33,8 +33,7 @@ template <typename GridwiseGemm,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename ORSGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
bool Deterministic>
typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -48,7 +47,6 @@ __global__ void
const ORSGridDescriptor_M ors_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -64,32 +62,15 @@ __global__ void
const long_index_t ors_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map,
i);
}
}
else
{
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map,
0);
}
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_ors_grid + ors_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ors_grid_desc_m,
block_2_ctile_map);
#else
ignore = p_y_grid;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -787,27 +768,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
GemmDataType,
GemmAccDataType,
ORSDataType,
YGridDesc_M_O,
ORSGridDesc_M,
BlockSize,
256,
128,
KPerBlock,
32,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
64,
64,
1,
4,
ABlockLdsExtraM,
BBlockLdsExtraN,
Deterministic>;
64>;
// Argument
struct Argument : public BaseArgument
{
@@ -1076,8 +1042,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::ORSGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
Deterministic>;
ComputeBasePtrOfStridedBatch>;
return launch_and_time_kernel(
stream_config,
@@ -1092,7 +1057,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.ors_grid_desc_m_,
arg.ors_block_2_ctile_map_,
arg.batch_count_,
arg.ors_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
arg.compute_base_ptr_of_batch_);
};
......
@@ -21,27 +21,12 @@
namespace ck {
template <typename InputDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatORS,
typename CGridDesc_M_N,
typename ORSGridDesc_M,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t ABlockLdsExtraM,
index_t BBlockLdsExtraN,
bool Deterministic>
index_t NPerBlock>
struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static constexpr auto I0 = Number<0>{};
@@ -56,44 +41,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
@@ -110,12 +65,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto NBlock = N / NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
@@ -141,7 +96,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
@@ -151,64 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcr)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>; // TransposeC
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
@@ -223,18 +120,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
FloatORS,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
StaticBuffer<AddressSpaceEnum::Vgpr, FloatORS, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
return MPerBlock * sizeof(FloatGemmAcc);
return MPerBlock * sizeof(FloatORS);
}
__device__ static void test() {}
@@ -246,8 +143,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const ORSGridDesc_M& ors_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const index_t block_idx_m)
const Block2CTileMap& block_2_ctile_map)
{
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
@@ -269,7 +165,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
const index_t block_work_idx_m = block_work_idx[I0];
constexpr auto ors_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1));
@@ -306,7 +202,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
FloatORS,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
@@ -321,13 +217,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
if constexpr(Deterministic)
{
block_sync_lds();
}
auto y_dot_ygrad_block_accum_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatORS*>(p_shared), MPerBlock);
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
@@ -366,7 +257,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(ors_grid_desc_m);
auto ors_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
FloatORS,
FloatORS,
decltype(ors_thread_desc_mblock_mrepeat_mwave_mperxdl),
decltype(ors_grid_desc_mblock_mrepeat_mwave_mperxdl),
......