Commit 31706d42 authored by danyao12's avatar danyao12
Browse files

modify bias LDS addrs in bwd kernels

parent 70d700b3
......@@ -1191,56 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
KPack>;
};
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, c_block_bytes_end);
}
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
......@@ -1273,12 +1223,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
......@@ -1354,6 +1308,66 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
2>;
};
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, d0_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
......@@ -1987,8 +2001,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) +
SharedMemTrait::p_slash_sgrad_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......@@ -2023,10 +2036,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
});
});
// load k
gemm_tile_k_blockwise_copy.RunWrite(GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
......
......@@ -1276,65 +1276,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
};
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(
gemm0_bytes_end, gemm1_bytes_end, gemm2_bytes_end, gemm3_bytes_end, c_block_bytes_end);
}
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
......@@ -1367,12 +1308,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
......@@ -1448,6 +1393,78 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
2>;
};
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
......@@ -2137,7 +2154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
......@@ -1259,68 +1259,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
......@@ -1353,12 +1291,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
......@@ -1434,6 +1376,79 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2>;
};
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto q_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(
p_slash_sgrad_bytes_end, softmax_bytes_end, d0_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
......@@ -2186,8 +2201,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) +
SharedMemTrait::p_slash_sgrad_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......@@ -2222,10 +2236,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
});
});
// load k
gemm_tile_k_blockwise_copy.RunWrite(GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
......
......@@ -1321,79 +1321,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
};
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
......@@ -1426,12 +1353,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
......@@ -1507,6 +1438,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>;
};
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
......@@ -2292,7 +2306,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment