Commit 796b544e authored by ltqin's avatar ltqin
Browse files

remove shared memory

parent dcfe312b
......@@ -51,7 +51,6 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......@@ -67,7 +66,6 @@ __global__ void
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m,
block_2_ctile_map);
......@@ -759,8 +757,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// GridwiseYDotYGrad
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
DDataType,
DDataType, // datatype
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
......
......@@ -50,7 +50,6 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......@@ -66,7 +65,6 @@ __global__ void
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m,
block_2_ctile_map);
......@@ -773,8 +771,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// GridwiseYDotYGrad
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
DDataType,
DDataType, // datatype
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
......
......@@ -37,7 +37,6 @@ __global__ void
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
......@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_);
......
......@@ -37,7 +37,6 @@ __global__ void
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
......@@ -73,7 +72,6 @@ __global__ void
GridwiseGemm::Run(arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_d_grid_ + d_batch_offset,
static_cast<void*>(p_shared),
arg_ptr[group_id].d_y_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].d_grid_desc_m_,
arg_ptr[group_id].d_block_2_ctile_map_);
......
......@@ -136,15 +136,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
return MPerBlock * sizeof(FloatD);
}
__device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid,
void* __restrict__ p_shared,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m,
......@@ -213,12 +207,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatD*>(p_shared), MPerBlock);
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0;
do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment