Commit 15713b20 authored by danyao12
Browse files

rename functions

parent 8defa341
...@@ -881,7 +881,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -881,7 +881,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ = d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_); y_grid_desc_m_o_);
// Print(); // Print();
......
...@@ -894,7 +894,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -894,7 +894,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ = d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_); y_grid_desc_m_o_);
// Print(); // Print();
......
...@@ -949,7 +949,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -949,7 +949,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto d_block_2_ctile_map = const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o); GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
......
...@@ -951,7 +951,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -951,7 +951,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto d_block_2_ctile_map = const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o); GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
......
...@@ -22,7 +22,7 @@ namespace ck { ...@@ -22,7 +22,7 @@ namespace ck {
template <typename InputDataType, template <typename InputDataType,
typename FloatD, typename FloatD,
typename CGridDesc_M_N, typename YGridDesc_M_N,
typename DGridDesc_M, typename DGridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -32,23 +32,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -32,23 +32,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
static_assert(BlockSize == MPerBlock, "BlockSize must be same with MPerBlock"); static_assert(BlockSize == MPerBlock, "BlockSize must be same with MPerBlock");
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n, __host__ __device__ static constexpr bool CheckValidity(const YGridDesc_M_N& y_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) if(!block_2_ctile_map.CheckValidity(y_grid_desc_m_n))
{ {
return false; return false;
} }
// const auto M = c_grid_desc_m_n.GetLength(I0); // const auto M = y_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = y_grid_desc_m_n.GetLength(I1);
if(N < NPerBlock) if(N < NPerBlock)
{ {
return false; return false;
...@@ -62,21 +60,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -62,21 +60,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
// { // {
// return false; // return false;
// } // }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const YGridDesc_M_N& y_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = y_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = y_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock; const auto NBlock = N / NPerBlock;
const auto y_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( const auto y_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n, y_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))), make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -86,7 +83,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -86,7 +83,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeORSGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m) MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
{ {
const index_t M = d_grid_desc_m.GetLength(I0); const index_t M = d_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock; const index_t MBlock = M / MPerBlock;
...@@ -100,19 +97,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -100,19 +97,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
return d_grid_desc_mblock_mperblock; return d_grid_desc_mblock_mperblock;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to Y matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, YGridDesc_M_N>(
c_grid_desc_m_n); y_grid_desc_m_n);
} }
using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(YGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(YGridDesc_M_N{}))>;
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_N_ struct YDotYGrad_M_N_
...@@ -240,7 +237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -240,7 +237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
oblock_idx++; oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)); } while(oblock_idx < y_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2));
auto d_grid_desc_mblock_mperblock = MakeORSGridDescriptor_MBlock_MPerBlock(d_grid_desc_m); auto d_grid_desc_mblock_mperblock = MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m);
auto d_thread_copy_vgpr_to_global = auto d_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<FloatD, ThreadwiseTensorSliceTransfer_v1r3<FloatD,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment