Commit a188073b authored by ltqin's avatar ltqin
Browse files

fix bug for group Block2CTileMap

parent 0b472e28
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -709,6 +709,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -709,6 +709,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
BlockSize, BlockSize,
BlockSize, BlockSize,
DKPerBlock>; DKPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
struct GroupKernelArg struct GroupKernelArg
{ {
...@@ -751,7 +753,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -751,7 +753,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_; d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_; index_t d_num_blocks_per_batch_;
...@@ -930,16 +932,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -930,16 +932,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]); const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
index_t d_block_start = d_grid_size_;
const auto d_block_2_ctile_map = const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o); d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count; index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end; d_grid_size_ = d_block_end;
......
...@@ -717,6 +717,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -717,6 +717,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
BlockSize, BlockSize,
BlockSize, BlockSize,
DKPerBlock>; DKPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
struct GroupKernelArg struct GroupKernelArg
{ {
// pointers // pointers
...@@ -758,7 +761,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -758,7 +761,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter // D parameter
DDataType* p_d_grid_; DDataType* p_d_grid_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_; DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d_y_grid_desc_mblock_mperblock_nblock_nperblock_; d_y_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t d_num_blocks_per_batch_; index_t d_num_blocks_per_batch_;
...@@ -933,15 +936,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -933,15 +936,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto d_grid_desc_m = const auto d_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]); DeviceOp::MakeLSEGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
const auto d_block_2_ctile_map = index_t d_block_start = d_grid_size_;
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o); const auto d_block_2_ctile_map = DBlock2CTileMap(y_grid_desc_m_o, d_block_start);
const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock = const auto d_y_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseYDotYGrad::MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o); y_grid_desc_m_o);
index_t d_num_blocks_per_batch = index_t d_num_blocks_per_batch =
d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o); d_block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o);
index_t d_block_start = d_grid_size_;
index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count; index_t d_block_end = d_block_start + d_num_blocks_per_batch * batch_count;
d_grid_size_ = d_block_end; d_grid_size_ = d_block_end;
......
...@@ -101,8 +101,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -101,8 +101,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n) MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, YGridDesc_M_N>( // should rewrite BlockToCTileMap_M00_N0_M01Adapt
y_grid_desc_m_n); return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, 1024, YGridDesc_M_N>(y_grid_desc_m_n);
} }
using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
...@@ -133,13 +133,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -133,13 +133,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
}; };
using YDotYGrad_M_N = YDotYGrad_M_N_<BlockSize, MPerBlock, NPerBlock>; using YDotYGrad_M_N = YDotYGrad_M_N_<BlockSize, MPerBlock, NPerBlock>;
template <typename Block2CTileMap>
__device__ static void Run(const InputDataType* __restrict__ p_y_grid, __device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid, FloatD* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m, const DGridDesc_M& d_grid_desc_m,
const DefaultBlock2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment