Commit a188073b authored by ltqin
Browse files

Fix Block2CTileMap bug in the grouped device ops

parent 0b472e28
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
......@@ -709,6 +709,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
BlockSize,
BlockSize,
DKPerBlock>;
// Block-to-C-tile map type used for the D (YDotYGrad) part of this device op:
// GridwiseYDotYGrad's DefaultBlock2CTileMap wrapped in OffsettedBlockToCTileMap.
// NOTE(review): presumably the wrapper shifts the incoming block id by a
// per-group starting block index — confirm against OffsettedBlockToCTileMap.
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
struct GroupKernelArg
{
......@@ -751,7 +753,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter
DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_;
......@@ -930,16 +932,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o);
index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end;
......
......@@ -717,6 +717,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
BlockSize,
BlockSize,
DKPerBlock>;
// Block-to-C-tile map type used for the D (YDotYGrad) part of this device op:
// GridwiseYDotYGrad's DefaultBlock2CTileMap wrapped in OffsettedBlockToCTileMap.
// NOTE(review): presumably the wrapper shifts the incoming block id by a
// per-group starting block index — confirm against OffsettedBlockToCTileMap.
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
struct GroupKernelArg
{
// pointers
......@@ -758,7 +761,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter
DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_;
......@@ -933,15 +936,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
const auto d_block_2_ctile_map =
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o);
index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end;
......
......@@ -101,8 +101,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
// Builds the default block-to-C-tile map for this gridwise op by constructing a
// BlockToCTileMap_M00_N0_M01Adapt from the given Y grid descriptor.
//
// y_grid_desc_m_n: Y grid descriptor forwarded unchanged to the map constructor.
// Returns the constructed BlockToCTileMap_M00_N0_M01Adapt instance.
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n)
{
// NOTE(review): two return statements appear back to back. This looks like a
// diff rendering where this statement (NPerBlock as the second template
// argument) is the pre-change version and the one below is its replacement —
// confirm against the actual file before relying on either.
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, YGridDesc_M_N>(
y_grid_desc_m_n);
// TODO: BlockToCTileMap_M00_N0_M01Adapt should be rewritten.
// NOTE(review): the second template argument here is the literal 1024 instead
// of NPerBlock; the rationale is not visible in this block — confirm before
// changing it.
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, 1024, YGridDesc_M_N>(y_grid_desc_m_n);
}
using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
......@@ -133,13 +133,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
};
using YDotYGrad_M_N = YDotYGrad_M_N_<BlockSize, MPerBlock, NPerBlock>;
template <typename Block2CTileMap>
__device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m,
const DefaultBlock2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map)
{
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment