Commit 796b544e authored by ltqin's avatar ltqin
Browse files

remove share memory

parent dcfe312b
...@@ -51,7 +51,6 @@ __global__ void ...@@ -51,7 +51,6 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...@@ -67,7 +66,6 @@ __global__ void ...@@ -67,7 +66,6 @@ __global__ void
GridwiseGemm::Run(p_y_grid + c_batch_offset, GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset, p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m, d_grid_desc_m,
block_2_ctile_map); block_2_ctile_map);
...@@ -759,8 +757,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1 ...@@ -759,8 +757,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// GridwiseYDotYGrad // GridwiseYDotYGrad
using GridwiseYDotYGrad = using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype DDataType, // datatype
DDataType,
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
......
...@@ -50,7 +50,6 @@ __global__ void ...@@ -50,7 +50,6 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...@@ -66,7 +65,6 @@ __global__ void ...@@ -66,7 +65,6 @@ __global__ void
GridwiseGemm::Run(p_y_grid + c_batch_offset, GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset, p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m, d_grid_desc_m,
block_2_ctile_map); block_2_ctile_map);
...@@ -773,8 +771,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -773,8 +771,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// GridwiseYDotYGrad // GridwiseYDotYGrad
using GridwiseYDotYGrad = using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype DDataType, // datatype
DDataType,
YGridDesc_M_O, YGridDesc_M_O,
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
......
...@@ -37,7 +37,6 @@ __global__ void ...@@ -37,7 +37,6 @@ __global__ void
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -73,7 +72,6 @@ __global__ void ...@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset, GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset, arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_, arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_); arg_ptr[group_id].d_block_2_ctile_map_);
......
...@@ -37,7 +37,6 @@ __global__ void ...@@ -37,7 +37,6 @@ __global__ void
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -73,7 +72,6 @@ __global__ void ...@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset, GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset, arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_, arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_); arg_ptr[group_id].d_block_2_ctile_map_);
......
...@@ -136,15 +136,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -136,15 +136,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
}; };
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>; using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
return MPerBlock * sizeof(FloatD);
}
__device__ static void Run(const InputDataType* __restrict__ p_y_grid, __device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid, FloatD* __restrict__ p_d_grid,
void* __restrict__ p_shared,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock, y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m, const DGridDesc_M& d_grid_desc_m,
...@@ -213,12 +207,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -213,12 +207,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{}; auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatD*>(p_shared), MPerBlock);
// clear accum buffers // clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear(); y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0; index_t oblock_idx = 0;
do do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment