Commit 31706d42 authored by danyao12's avatar danyao12
Browse files

modify bias LDS addrs in bwd kernels

parent 70d700b3
...@@ -1191,56 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1191,56 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
KPack>; KPack>;
}; };
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, c_block_bytes_end);
}
// D0 // D0
static constexpr auto D0M2 = Number<4>{}; static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2; static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...@@ -1274,11 +1224,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1274,11 +1224,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
static constexpr index_t NThreadClusterLengths = MPerXdl; static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock); static_assert(MPerXdl <= KPerBlock);
...@@ -1354,6 +1308,66 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1354,6 +1308,66 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
2>; 2>;
}; };
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, d0_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -1987,8 +2001,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1987,8 +2001,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
SharedMemTrait::p_slash_sgrad_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
...@@ -2023,10 +2036,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2023,10 +2036,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}); });
}); });
// load k
gemm_tile_k_blockwise_copy.RunWrite(GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
......
...@@ -1276,65 +1276,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1276,65 +1276,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
} }
}; };
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(
gemm0_bytes_end, gemm1_bytes_end, gemm2_bytes_end, gemm3_bytes_end, c_block_bytes_end);
}
// D0 // D0
static constexpr auto D0M2 = Number<4>{}; static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2; static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...@@ -1368,11 +1309,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1368,11 +1309,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
static constexpr index_t NThreadClusterLengths = 32; static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32); static_assert(NPerXdl == 32);
...@@ -1448,6 +1393,78 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1448,6 +1393,78 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
2>; 2>;
}; };
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -2137,7 +2154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2137,7 +2154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
...@@ -1259,68 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1259,68 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}; };
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>; using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
// D0 // D0
static constexpr auto D0M2 = Number<4>{}; static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2; static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...@@ -1354,11 +1292,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1354,11 +1292,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
static constexpr index_t NThreadClusterLengths = MPerXdl; static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock); static_assert(MPerXdl <= KPerBlock);
...@@ -1434,6 +1376,79 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1434,6 +1376,79 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2>; 2>;
}; };
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(
p_slash_sgrad_bytes_end, softmax_bytes_end, d0_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -2186,8 +2201,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2186,8 +2201,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
SharedMemTrait::p_slash_sgrad_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
...@@ -2222,10 +2236,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2222,10 +2236,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}); });
}); });
// load k
gemm_tile_k_blockwise_copy.RunWrite(GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
......
...@@ -1321,79 +1321,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1321,79 +1321,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
}; };
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
// D0 // D0
static constexpr auto D0M2 = Number<4>{}; static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2; static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...@@ -1427,11 +1354,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1427,11 +1354,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
static constexpr index_t NThreadClusterLengths = 32; static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32); static_assert(NPerXdl == 32);
...@@ -1507,6 +1438,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1507,6 +1438,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>; 2>;
}; };
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -2292,7 +2306,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2292,7 +2306,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment