Commit cf9ef868 authored by ltqin

remove useless code

Remove the unused Deterministic template parameter (and its serial nblock tile loop) plus the dead Gemm0 tuning parameters from the Y dot dY grid-wise kernel and its V1/V2 device-side callers.

parent 118742b6
@@ -34,8 +34,7 @@ template <typename GridwiseGemm,
           typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
           typename ORSGridDescriptor_M,
           typename Block2CTileMap,
-          typename ComputeBasePtrOfStridedBatch,
-          bool Deterministic>
+          typename ComputeBasePtrOfStridedBatch>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -49,7 +48,6 @@ __global__ void
         const ORSGridDescriptor_M ors_grid_desc_m,
         const Block2CTileMap block_2_ctile_map,
         const index_t batch_count,
-        const index_t nblock,
         const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -65,32 +63,15 @@ __global__ void
     const long_index_t ors_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < nblock; i++)
-        {
-            GridwiseGemm::template Run(p_y_grid + c_batch_offset,
-                                       p_ygrad_grid + c_batch_offset,
-                                       p_ors_grid + ors_batch_offset,
-                                       p_shared,
-                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                       ors_grid_desc_m,
-                                       block_2_ctile_map,
-                                       i);
-        }
-    }
-    else
-    {
-        // GridwiseGemm::test();
-        GridwiseGemm::Run(p_y_grid + c_batch_offset,
-                          p_ygrad_grid + c_batch_offset,
-                          p_ors_grid + ors_batch_offset,
-                          p_shared,
-                          c_grid_desc_mblock_mperblock_nblock_nperblock,
-                          ors_grid_desc_m,
-                          block_2_ctile_map,
-                          0);
-    }
+    // GridwiseGemm::test();
+    GridwiseGemm::Run(p_y_grid + c_batch_offset,
+                      p_ygrad_grid + c_batch_offset,
+                      p_ors_grid + ors_batch_offset,
+                      p_shared,
+                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                      ors_grid_desc_m,
+                      block_2_ctile_map);
 #else
     ignore = p_y_grid;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
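Context for reviewers: with the Deterministic branch gone, each workgroup processes exactly one output tile, so the kernel no longer needs `nblock` or the serial tile loop. The elided kernel preamble above presumably still derives the batch index the way other CK batched kernels do; a minimal sketch under that assumption (`get_grid_size`/`get_block_1d_id` are existing CK helpers, the even grid split is an assumption):

```cpp
// Sketch, not the patched source: how the g_idx consumed by GetBasePtr(g_idx) /
// GetLSEBasePtr(g_idx) is typically derived in CK batched kernels, assuming the
// launcher sizes the grid as batch_count * tiles_per_batch.
const index_t num_blocks_per_batch =
    __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
    __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// block_2_ctile_map then resolves the per-batch block id to an
// (mblock, oblock) tile coordinate inside GridwiseGemm::Run().
```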
@@ -771,27 +752,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     using GridwiseYDotYGrad =
         GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
                                                             // datatype
-                                                            GemmDataType,
-                                                            GemmAccDataType,
                                                             ORSDataType,
                                                             YGridDesc_M_O,
                                                             ORSGridDesc_M,
                                                             BlockSize,
                                                             128,
-                                                            128,
-                                                            KPerBlock,
-                                                            Gemm1NPerBlock,
-                                                            Gemm1KPerBlock,
-                                                            AK1,
-                                                            BK1,
-                                                            B1K1,
-                                                            32,
-                                                            32,
-                                                            1,
-                                                            4,
-                                                            ABlockLdsExtraM,
-                                                            BBlockLdsExtraN,
-                                                            Deterministic>;
+                                                            32>;
 
     // Argument
     struct Argument : public BaseArgument
     {
@@ -1061,8 +1027,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                 DeviceOp::ORSGridDesc_M,
                 typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
-                ComputeBasePtrOfStridedBatch,
-                Deterministic>;
+                ComputeBasePtrOfStridedBatch>;
 
             return launch_and_time_kernel(
                 stream_config,
@@ -1077,7 +1042,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 arg.ors_grid_desc_m_,
                 arg.ors_block_2_ctile_map_,
                 arg.batch_count_,
-                arg.ors_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
                 arg.compute_base_ptr_of_batch_);
         };
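Context for reviewers: the deleted launch argument re-computed the per-batch tile count that only the Deterministic loop consumed; the kernel now takes just `batch_count`. A hedged sketch of the grid sizing this implies (the actual computation sits in the elided context above):

```cpp
// Illustrative: one workgroup per (batch, M-tile) pair. CalculateGridSize()
// yields the per-batch tile count; this same value was previously also
// forwarded into the kernel as `nblock` for the removed Deterministic loop.
const index_t grid_size =
    arg.batch_count_ * arg.ors_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_);
```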
...
@@ -33,8 +33,7 @@ template <typename GridwiseGemm,
           typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
           typename ORSGridDescriptor_M,
           typename Block2CTileMap,
-          typename ComputeBasePtrOfStridedBatch,
-          bool Deterministic>
+          typename ComputeBasePtrOfStridedBatch>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
@@ -48,7 +47,6 @@ __global__ void
         const ORSGridDescriptor_M ors_grid_desc_m,
         const Block2CTileMap block_2_ctile_map,
         const index_t batch_count,
-        const index_t nblock,
         const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -64,32 +62,15 @@ __global__ void
     const long_index_t ors_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 
-    if constexpr(Deterministic)
-    {
-        for(index_t i = 0; i < nblock; i++)
-        {
-            GridwiseGemm::template Run(p_y_grid + c_batch_offset,
-                                       p_ygrad_grid + c_batch_offset,
-                                       p_ors_grid + ors_batch_offset,
-                                       p_shared,
-                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                       ors_grid_desc_m,
-                                       block_2_ctile_map,
-                                       i);
-        }
-    }
-    else
-    {
-        // GridwiseGemm::test();
-        GridwiseGemm::Run(p_y_grid + c_batch_offset,
-                          p_ygrad_grid + c_batch_offset,
-                          p_ors_grid + ors_batch_offset,
-                          p_shared,
-                          c_grid_desc_mblock_mperblock_nblock_nperblock,
-                          ors_grid_desc_m,
-                          block_2_ctile_map,
-                          0);
-    }
+    // GridwiseGemm::test();
+    GridwiseGemm::Run(p_y_grid + c_batch_offset,
+                      p_ygrad_grid + c_batch_offset,
+                      p_ors_grid + ors_batch_offset,
+                      p_shared,
+                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                      ors_grid_desc_m,
+                      block_2_ctile_map);
 #else
     ignore = p_y_grid;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -787,27 +768,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     using GridwiseYDotYGrad =
         GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
                                                             // datatype
-                                                            GemmDataType,
-                                                            GemmAccDataType,
                                                             ORSDataType,
                                                             YGridDesc_M_O,
                                                             ORSGridDesc_M,
                                                             BlockSize,
                                                             256,
-                                                            128,
-                                                            KPerBlock,
-                                                            32,
-                                                            Gemm1KPerBlock,
-                                                            AK1,
-                                                            BK1,
-                                                            B1K1,
-                                                            64,
-                                                            64,
-                                                            1,
-                                                            4,
-                                                            ABlockLdsExtraM,
-                                                            BBlockLdsExtraN,
-                                                            Deterministic>;
+                                                            64>;
 
     // Argument
     struct Argument : public BaseArgument
     {
@@ -1076,8 +1042,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                 DeviceOp::ORSGridDesc_M,
                 typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
-                ComputeBasePtrOfStridedBatch,
-                Deterministic>;
+                ComputeBasePtrOfStridedBatch>;
 
             return launch_and_time_kernel(
                 stream_config,
@@ -1092,7 +1057,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 arg.ors_grid_desc_m_,
                 arg.ors_block_2_ctile_map_,
                 arg.batch_count_,
-                arg.ors_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
                 arg.compute_base_ptr_of_batch_);
         };
...
@@ -21,27 +21,12 @@
 namespace ck {
 
 template <typename InputDataType,
-          typename GemmDataType,
-          typename FloatGemmAcc,
           typename FloatORS,
           typename CGridDesc_M_N,
           typename ORSGridDesc_M,
           index_t BlockSize,
           index_t MPerBlock,
-          index_t NPerBlock,
-          index_t KPerBlock,
-          index_t Gemm1NPerBlock,
-          index_t Gemm1KPerBlock,
-          index_t AK1Value,
-          index_t BK1Value,
-          index_t B1K1Value,
-          index_t MPerXdl,
-          index_t NPerXdl,
-          index_t MXdlPerWave,
-          index_t NXdlPerWave,
-          index_t ABlockLdsExtraM,
-          index_t BBlockLdsExtraN,
-          bool Deterministic>
+          index_t NPerBlock>
 struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 {
     static constexpr auto I0 = Number<0>{};
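Context for reviewers: the struct now takes only the seven parameters the Y∘dY row reduction actually reads. A hypothetical instantiation mirroring the V1 device op above (element types and BlockSize are illustrative assumptions; the tile sizes are the literals from V1):

```cpp
// Hypothetical instantiation (types/BlockSize assumed, tile sizes from V1 above):
using YDotYGradExample = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
    half_t,         // InputDataType: element type of Y / dY (assumed fp16)
    float,          // FloatORS: accumulator and output type of the row sums
    YGridDesc_M_O,  // CGridDesc_M_N
    ORSGridDesc_M,
    256,            // BlockSize (assumed)
    128,            // MPerBlock (from V1)
    32>;            // NPerBlock (from V1)
```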
@@ -56,44 +41,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     static constexpr auto I9 = Number<9>{};
 
     static constexpr auto WaveSize = 64;
-    // K1 should be Number<...>
-    // Gemm0
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto AK1 = Number<AK1Value>{};
-    static constexpr auto BK1 = Number<BK1Value>{};
-
-    // Gemm1
-    static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
-    static constexpr auto B1K1 = Number<B1K1Value>{};
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 
-    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
-    {
-        // A matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
-            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
-    {
-        // B matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
-            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
-    }
-
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n,
                                                             const Block2CTileMap& block_2_ctile_map)
     {
-        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
-                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
-                      "Invalid tuning param!");
-
         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
@@ -110,12 +65,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const auto N = c_grid_desc_m_n.GetLength(I1);
 
         const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / Gemm1NPerBlock;
+        const auto NBlock = N / NPerBlock;
 
         const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
             c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
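Context for reviewers: a worked example of the unmerge, with shapes chosen purely for illustration:

```cpp
// Worked example (illustrative shapes, not values from the patch):
constexpr int M = 1024, O = 32, MPerBlock = 128, OPerBlock = 32;
// MBlock = 1024 / 128 = 8, OBlock = 32 / 32 = 1, so (M, O) unmerges to
// (MBlock, MPerBlock, OBlock, OPerBlock) = (8, 128, 1, 32) and a tile
// coordinate (m0, o0) from the block-to-tile map addresses a 128 x 32 slab.
static_assert(M / MPerBlock == 8 && O / OPerBlock == 1, "tile counts");
```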
@@ -141,7 +96,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
             c_grid_desc_m_n);
     }
@@ -151,64 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
 
-    // S / dP Gemm (type 1 rcr)
-    struct Gemm0
-    {
-        // A matrix in LDS memory, dst of blockwise copy
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
-        template <typename ABlockDesc_AK0_M_AK1>
-        __host__ __device__ static constexpr auto
-        MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
-        {
-            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
-
-            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
-                ABlockDesc_AK0_M_AK1{});
-        }
-
-        template <typename BBlockDesc_BK0_N_BK1>
-        __host__ __device__ static constexpr auto
-        MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
-        {
-            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
-
-            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
-                BBlockDesc_BK0_N_BK1{});
-        }
-
-        static constexpr index_t KPack =
-            math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-
-        // Blockwise gemm with transposed XDL output
-        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
-            BlockSize,
-            GemmDataType,
-            FloatGemmAcc,
-            decltype(a_block_desc_ak0_m_ak1),
-            decltype(b_block_desc_bk0_n_bk1),
-            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
-            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
-            MPerBlock,
-            NPerBlock,
-            KPerBlock,
-            MPerXdl,
-            NPerXdl,
-            MXdlPerWave,
-            NXdlPerWave,
-            KPack,
-            true>; // TransposeC
-
-        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
-    };
-
     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
@@ -223,18 +120,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
-                                        FloatGemmAcc,
+                                        FloatORS,
                                         ThreadSliceLength_M * ThreadSliceLength_O,
                                         true>;
 
         using DstBufType =
-            StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
+            StaticBuffer<AddressSpaceEnum::Vgpr, FloatORS, ThreadSliceLength_M, true>;
     };
 
-    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
+    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>;
 
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        return MPerBlock * sizeof(FloatGemmAcc);
+        return MPerBlock * sizeof(FloatORS);
     }
 
     __device__ static void test() {}
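Context for reviewers: the LDS footprint is now expressed in the output type. A hedged sizing example, using the V1 tile (MPerBlock = 128) and assuming FloatORS = float:

```cpp
// Illustrative arithmetic only: 128 rows * sizeof(float) = 512 bytes of LDS,
// one partial row-sum slot per row of the block tile. The size is unchanged
// whenever FloatGemmAcc was float; the rewrite just drops the dependence on
// the deleted template parameter.
static_assert(128 * sizeof(float) == 512, "GetSharedMemoryNumberOfByte() for the V1 tile");
```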
@@ -246,8 +143,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
             y_grid_desc_mblock_mperblock_oblock_operblock,
         const ORSGridDesc_M& ors_grid_desc_m,
-        const Block2CTileMap& block_2_ctile_map,
-        const index_t block_idx_m)
+        const Block2CTileMap& block_2_ctile_map)
     {
         const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
@@ -269,7 +165,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             return;
         }
 
-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_m = block_work_idx[I0];
 
         constexpr auto ors_thread_desc_mblock_mrepeat_mwave_mperxdl =
             make_naive_tensor_descriptor_packed(make_tuple(I1, I1));
@@ -306,7 +202,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             InputDataType,
-            FloatGemmAcc,
+            FloatORS,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
@@ -321,13 +217,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
         auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
         auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
-        auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
-
-        if constexpr(Deterministic)
-        {
-            block_sync_lds();
-        }
+        auto y_dot_ygrad_block_accum_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatORS*>(p_shared), MPerBlock);
 
         // clear accum buffers
         y_dot_ygrad_thread_accum_buf.Clear();
@@ -366,7 +257,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(ors_grid_desc_m);
 
         auto ors_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
+            FloatORS,
             FloatORS,
             decltype(ors_thread_desc_mblock_mrepeat_mwave_mperxdl),
             decltype(ors_grid_desc_mblock_mrepeat_mwave_mperxdl),
...
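Context for reviewers skimming the patch: the kernel this file implements is the "Y dot dY" reduction of the attention backward pass, D[m] = sum over o of Y[m][o] * dY[m][o], accumulated in FloatORS. A minimal host-side reference (hedged sketch with plain float arrays, not CK code):

```cpp
#include <cstddef>

// Reference semantics of the Y dot dY reduction: one row-wise dot product
// per row m, accumulated in float just as the kernel accumulates in FloatORS.
void y_dot_ygrad_reference(const float* y, const float* ygrad, float* d,
                           std::size_t M, std::size_t O)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        float acc = 0.f;
        for(std::size_t o = 0; o < O; ++o)
            acc += y[m * O + o] * ygrad[m * O + o];
        d[m] = acc; // the "ORS" (row sum) output, one scalar per row
    }
}
```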