Commit c5fd087e authored by aska-0096

Attn, skip B LDS

parent 6e28a8ac
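The diff below extends to the B0 and B1 operands the same compile-time choice the kernel already makes for A: when B0EnableLds / B1EnableLds is set, tiles are staged through LDS behind a 3-D (K0, N, K1) block descriptor; when it is cleared, each thread loads its fragment straight from global memory into VGPRs through a 6-D descriptor shaped like the WMMA register layout. The sketch below illustrates only that dispatch pattern; it is plain standalone C++17 with placeholder names (make_b_block_shape, EnableLds) and assumed tile sizes, not the library's actual descriptor API.

#include <array>
#include <cstddef>

// Illustrative sketch of the compile-time dispatch this commit adds for B0/B1
// (names and sizes are placeholders, not the CK API). With LDS enabled the
// block "descriptor" is a 3-D (K0, N, K1) tile; with LDS skipped it is a 6-D
// per-thread shape (KWmma, Repeat, Waves, KRow, PerWmma, K1) matching the
// WMMA register layout, so the copy can go global -> VGPR directly.
template <bool EnableLds, std::size_t KPerBlock, std::size_t NPerBlock,
          std::size_t K1, std::size_t WmmaK, std::size_t Repeat>
constexpr auto make_b_block_shape()
{
    if constexpr(EnableLds)
    {
        return std::array<std::size_t, 3>{KPerBlock / K1, NPerBlock, K1};
    }
    else
    {
        // wave/row dims are 1 because each thread only owns its own fragment
        return std::array<std::size_t, 6>{KPerBlock / WmmaK, Repeat, 1, 1, 1, K1};
    }
}

// Assumed sizes, for illustration only: KPerBlock=32, NPerBlock=64, K1=8,
// WmmaK=16, Repeat=2.
static_assert(make_b_block_shape<true, 32, 64, 8, 16, 2>()[0] == 4);  // K0 = 32/8
static_assert(make_b_block_shape<false, 32, 64, 8, 16, 2>()[0] == 2); // KWmma = 32/16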
......@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
Number<K1>{});
if constexpr(B0EnableLds)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<K1>{});
}
else
{
return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<WmmaK>{},
Number<LRepeat>{},
Number<LWaves>{},
Number<LPerWmma>{},
Number<K1>{});
}
}
static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
Number<L1>{});
if constexpr(B1EnableLds)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<L1>{});
}
else
{
return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<WmmaK>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWmma>{},
Number<L1>{});
}
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
......@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc_BK0_L_BK1,
B1GridDesc_BL0_N_BL1,
B0GridDesc,
B1GridDesc,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
......@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_bk0_l_bk1_{
b0_grid_desc{
DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
b1_grid_desc{
DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{
Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{
......@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ignore = acc1_biases_gs_ms_ns_lengths;
ignore = acc1_biases_gs_ms_ns_strides;
if(GridwiseOp::CheckValidity(a_grid_desc,
b0_grid_desc_bk0_l_bk1_,
b1_grid_desc_bl0_n_bl1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseOp::CheckValidity(
a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Tensor Descriptors
AGridDesc a_grid_desc;
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
......@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1DataType,
CDataType,
DeviceOp::AGridDesc,
DeviceOp::B0GridDesc_BK0_L_BK1,
DeviceOp::B1GridDesc_BL0_N_BL1,
DeviceOp::B0GridDesc,
DeviceOp::B1GridDesc,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
B0ElementwiseOperation,
......@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg.p_b1_grid_,
arg.p_c_grid_,
arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b0_element_op_,
......@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
......
......@@ -18,14 +18,14 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB0,
typename FloatB1,
typename FloatC,
template <typename GridwiseOp,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AGridDesc,
typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_BL0_N_BL1,
typename B0GridDesc,
typename B1GridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
......@@ -41,13 +41,13 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
const FloatA* __restrict__ p_a_grid,
const FloatB0* __restrict__ p_b0_grid,
const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
const B0GridDesc b0_grid_desc,
const B1GridDesc b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
......@@ -61,7 +61,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -76,30 +76,30 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc_bk0_l_bk1,
b1_grid_desc_l0_n_l1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b0_grid_desc_bk0_l_bk1;
ignore = b1_grid_desc_l0_n_l1;
ignore = b0_grid_desc;
ignore = b1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b0_element_op;
......@@ -115,13 +115,13 @@ __global__ void
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template <typename FloatA,
typename FloatB0,
typename FloatAcc0,
typename FloatB1,
typename FloatAcc1,
typename FloatCShuffle,
typename FloatC,
template <typename ADataType,
typename B0DataType,
typename Acc0DataType,
typename B1DataType,
typename Acc1DataType,
typename CShuffleDataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
......@@ -129,8 +129,8 @@ template <typename FloatA,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_BL0_N_BL1,
typename B0GridDesc,
typename B1GridDesc,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t LPerBlock,
......@@ -163,7 +163,7 @@ template <typename FloatA,
index_t B0BlockTransferDstScalarPerVector_K1,
bool B0ThreadTransferSrcResetCoordinateAfterRun,
bool B0EnableLds,
bool B0BlockLdsExtraN,
bool B0BlockLdsExtraL,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
......@@ -204,8 +204,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static constexpr auto BL1 = Number<L1Value>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaL = 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -250,6 +252,73 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return a_block_desc;
}
__host__ __device__ static constexpr auto MakeB0BlockDescriptor()
{
constexpr auto b0_block_desc = [&]() {
if constexpr(B0EnableLds)
{
// K0->L->BK1 Per Block
constexpr auto K0PerBlock = KPerBlock / BK1;
constexpr auto max_lds_align = BK1;
if constexpr(B0BlockLdsExtraL)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1),
make_tuple(Number<LPerBlock + 1>{} * BK1, BK1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->LRepeat->LWave->KRow->LPerWmma->BK1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1),
make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1));
}
}();
return b0_block_desc;
}
__host__ __device__ static constexpr auto MakeB1BlockDescriptor()
{
constexpr auto b1_block_desc = [&]() {
if constexpr(B1EnableLds)
{
// L0->N->BL1 Per Block
constexpr auto max_lds_align = BL1;
if constexpr(B1BlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1),
make_tuple(Number<NPerBlock + 1>{} * BL1, BL1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1), max_lds_align);
}
}
else
{
constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
// LWmma->NRepeat->NWave->LRow->NPerWmma->BL1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1),
make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1));
}
}();
return b1_block_desc;
}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
......@@ -270,6 +339,44 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return a_block_copy_step;
}
__host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep()
{
constexpr auto b0_block_copy_step = [&]() {
if constexpr(B0EnableLds)
{
constexpr auto K0PerBlock = KPerBlock / BK1;
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return b0_block_copy_step;
}
__host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep()
{
constexpr auto b1_block_copy_step = [&]() {
if constexpr(B1EnableLds)
{
return make_multi_index(L0PerBlock, 0, 0);
}
else
{
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return b1_block_copy_step;
}
// Describes how data is read from the (LDS/VGPR) buffer
template <typename ABlockDesc_>
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
......@@ -323,26 +430,61 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return a_wave_desc;
}
template <typename B0BlockDesc_BK0_L_BK1>
__host__ __device__ static constexpr auto
MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
template <typename B0BlockDesc_>
__host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
{
constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
return transform_tensor_descriptor(
B0BlockDesc_BK0_L_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
constexpr auto b0_wave_desc = [&]() {
if constexpr(B0EnableLds)
{
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
B0BlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(make_tuple(
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
else
{
// KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I5);
// Workaround, Freeze transform
return transform_tensor_descriptor(
B0BlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<LRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
return b0_wave_desc;
}
template <typename A1BlockDesc_AL0_M_AL1>
__host__ __device__ static constexpr auto
MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
{
constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
......@@ -356,37 +498,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B1BlockDesc_BL0_N_BL1>
__host__ __device__ static constexpr auto
MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
template <typename B1BlockDesc_>
__host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&)
{
constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
return transform_tensor_descriptor(
B1BlockDesc_BL0_N_BL1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
constexpr auto b1_wave_desc = [&]() {
if constexpr(B1EnableLds)
{
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
B1BlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_L0>{}),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_L1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
else
{
// LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1
constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I5);
__host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<LPerBlock>{}, BK1),
make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
}
// Workaround, Freeze transform
return transform_tensor_descriptor(
B1BlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<LWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_L1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BL0, Number<NPerBlock>{}, BL1),
make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
return b1_wave_desc;
}
__host__ __device__ static constexpr auto
......@@ -410,31 +571,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
{
// LDS allocation for A and B: be careful of alignment
const index_t gemm0_bytes_end =
(SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
(SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType));
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset +
SharedMemTrait::b1_block_space_size_aligned * sizeof(FloatB1));
SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType));
const index_t softmax_bytes_end =
SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned * sizeof(FloatAcc0);
SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc& a_grid_desc,
const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(LPerBlock % (LPerWmma * LRepeat)) == 0,
......@@ -455,10 +615,40 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
}
};
const auto GetB0ProblemsizeLK = [&]() {
if constexpr(B0EnableLds)
{
return make_tuple(b0_grid_desc.GetLength(I1),
b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
b0_grid_desc.GetLength(I4),
b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
b0_grid_desc.GetLength(I5));
}
};
const auto GetB1ProblemsizeNL = [&]() {
if constexpr(B1EnableLds)
{
return make_tuple(b1_grid_desc.GetLength(I1),
b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
b1_grid_desc.GetLength(I4),
b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
b1_grid_desc.GetLength(I5));
}
};
const auto M = GetAProblemsizeMK()[I0];
const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
const auto L = GetB0ProblemsizeLK()[I0];
const auto K = GetAProblemsizeMK()[I1];
const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
const auto N = GetB1ProblemsizeNL()[I0];
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
{
......@@ -567,17 +757,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
max_lds_align)
: 0;
static constexpr auto b0_block_space_size_aligned =
B0EnableLds
? math::integer_least_multiple(
GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize(),
max_lds_align)
: 0;
B0EnableLds ? math::integer_least_multiple(
MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align)
: 0;
static constexpr auto b1_block_space_size_aligned =
B1EnableLds
? math::integer_least_multiple(
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize(),
max_lds_align)
: 0;
B1EnableLds ? math::integer_least_multiple(
MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a_block_space_size_aligned;
......@@ -599,14 +785,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
template <bool HasMainKBlockLoop,
typename C0MatrixMask,
typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB0* __restrict__ p_b0_grid,
const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const B0GridDesc& b0_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
......@@ -623,9 +809,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
p_b0_grid, b0_grid_desc.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
p_b1_grid, b1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......@@ -648,17 +834,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
const auto K = [&](){
if constexpr(AEnableLds){
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
}
}();
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
constexpr auto b0_block_desc = MakeB0BlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
......@@ -666,7 +843,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
{
constexpr auto AK0PerBlock = KPerBlock/ AK1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<ADataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
SharedMemTrait::a_block_space_size_aligned);
auto a_blockwise_copy =
......@@ -677,8 +854,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename BlockSliceLengths, */ Sequence<AK0PerBlock, MPerBlock, AK1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
......@@ -705,13 +882,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatA,
FloatA,
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
......@@ -736,20 +913,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return make_tuple(a_block_buf, a_blockwise_copy);
}
};
auto b0_block_trait = [&](){
if constexpr(B0EnableLds)
{
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<B0DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
SharedMemTrait::b0_block_space_size_aligned);
// B matrix blockwise copy
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, LPerBlock, BK1>,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
FloatB0,
FloatB0,
decltype(b0_grid_desc_k0_l_k1),
decltype(b0_block_desc_k0perblock_lperblock_k1),
B0DataType,
B0DataType,
decltype(b0_grid_desc),
decltype(b0_block_desc),
B0BlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
B0BlockTransferSrcVectorDim,
......@@ -760,15 +943,57 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
1,
B0ThreadTransferSrcResetCoordinateAfterRun,
true>(
b0_grid_desc_k0_l_k1,
b0_grid_desc,
make_multi_index(0, 0, 0),
b0_element_op,
b0_block_desc_k0perblock_lperblock_k1,
b0_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(b0_block_buf, b0_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
b0_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b0_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<B0DataType,
B0DataType,
decltype(b0_grid_desc),
decltype(b0_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<LRepeat>{},
I1,
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B0BlockTransferSrcScalarPerVector,
B0ThreadTransferSrcResetCoordinateAfterRun,
true>(
b0_grid_desc,
make_multi_index(0,
0/(LWaves * LPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b0_block_buf, b0_blockwise_copy);
}
};
auto a_block_buf = a_block_trait()[I0];
auto a_blockwise_copy = a_block_trait()[I1];
auto b0_block_buf = b0_block_trait()[I0];
auto b0_blockwise_copy = b0_block_trait()[I1];
/*******************************************************************************/
// Gemm0
......@@ -776,11 +1001,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
auto blockwise_gemm0 = BlockwiseGemmWMMA<
BlockSize,
FloatA,
FloatB0,
FloatAcc0,
ADataType,
B0DataType,
Acc0DataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
decltype(MakeB0WaveDescriptor(b0_block_desc)),
MPerBlock,
LPerBlock,
KPerBlock,
......@@ -816,16 +1041,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
/*******************************************************************************/
// LDS allocation for A and B: be careful of alignment
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
SharedMemTrait::b0_block_space_size_aligned);
/*******************************************************************************/
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b0_block_slice_copy_step = make_multi_index(BK0, 0, 0);
constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep();
const auto a_block_reset_copy_step = [&](){
if constexpr(AEnableLds){
......@@ -836,14 +1055,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
}
}();
const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
const auto b0_block_reset_copy_step = [&](){
if constexpr(B0EnableLds){
return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
}
else{
return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0);
}
}();
const auto K = [&](){
if constexpr(AEnableLds){
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
}
}();
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
/*******************************************************************************/
// softmax
/*******************************************************************************/
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
static_cast<Acc0DataType*>(p_shared) + SharedMemTrait::reduction_space_offset,
SharedMemTrait::reduction_space_size_aligned);
// get acc0 7D thread cluster
constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
......@@ -879,7 +1114,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatAcc0,
Acc0DataType,
decltype(threadid_to_l_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_l),
decltype(thread_slice_desc_m_l)>{};
......@@ -889,15 +1124,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
running_sum = 0;
running_sum_new = 0;
running_max = NumericLimits<FloatAcc0>::Lowest();
running_max_new = NumericLimits<FloatAcc0>::Lowest();
running_max = NumericLimits<Acc0DataType>::Lowest();
running_max_new = NumericLimits<Acc0DataType>::Lowest();
/*******************************************************************************/
// set up Gemm1
/*******************************************************************************/
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);
// Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
// A1 matrix in VGPR
constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
......@@ -915,8 +1146,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatAcc0,
FloatA,
Acc0DataType,
ADataType,
decltype(acc0_thread_desc_l0perblock_mperblock_l1),
decltype(a1_thread_desc_l0perblock_mperblock_l1),
tensor_operation::element_wise::PassThrough,
......@@ -925,8 +1156,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
2,
laccvgprs>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
constexpr auto b1_block_desc = MakeB1BlockDescriptor();
auto b1_block_trait = [&](){
if constexpr(B1EnableLds)
{
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
SharedMemTrait::b1_block_space_size_aligned);
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
......@@ -934,10 +1176,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatB1,
/* typename DstData, */ FloatB1,
/* typename SrcDesc, */ decltype(b1_grid_desc_l0_n_l1),
/* typename DstDesc, */ decltype(b1_block_desc_l0perblock_nperblock_l1),
/* typename SrcData, */ B1DataType,
/* typename DstData, */ B1DataType,
/* typename SrcDesc, */ decltype(b1_grid_desc),
/* typename DstDesc, */ decltype(b1_block_desc),
/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
......@@ -949,26 +1191,64 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
NumGemmKPrefetchStage>(
b1_grid_desc_l0_n_l1,
b1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_l0perblock_nperblock_l1,
b1_block_desc,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB1*>(p_shared)+ SharedMemTrait::b1_block_space_offset,
SharedMemTrait::b1_block_space_size_aligned);
return make_tuple(b1_block_buf, b1_blockwise_copy);
}
else
{
// Thread-wise copy
// LTilePerBlock/WmmaL -> NRepeat -> NWaves -> LRow -> NPerWmma -> L1
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>(
b1_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b1_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<B1DataType,
B1DataType,
decltype(b1_grid_desc),
decltype(b1_block_desc),
Sequence<Number<LWmmaPerBlock>{},
Number<NRepeat>{},
I1,
I1,
I1,
Number<L1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B1BlockTransferSrcScalarPerVector,
B1ThreadTransferSrcResetCoordinateAfterRun,
true>(
b1_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b1_block_buf, b1_blockwise_copy);
}
};
auto b1_block_buf = b1_block_trait()[I0];
auto b1_blockwise_copy = b1_block_trait()[I1];
constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep();
auto blockwise_gemm1 =
BlockwiseGemmWMMA<BlockSize,
FloatA,
FloatB1,
FloatAcc1,
decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
ADataType,
B1DataType,
Acc1DataType,
decltype(MakeA1WaveDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
decltype(MakeB1WaveDescriptor(b1_block_desc)),
MPerBlock,
NPerBlock,
LTilePerBlock,
......@@ -983,11 +1263,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
const auto L = [&](){
if constexpr(B0EnableLds){
return b0_grid_desc.GetLength(I1);
}
else{
return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I4);
}
}();
const index_t num_gemm1_l_block_outer_loop = L / LPerBlock;
constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;
// Initialize C
StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, Acc1DataType, acc1_thread_buf.Size(), true> c_thread_buf;
c_thread_buf.Clear();
/*******************************************************************************/
......@@ -1014,8 +1303,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b0_grid_desc_k0_l_k1,
b0_block_desc_k0perblock_lperblock_k1,
b0_grid_desc,
b0_block_desc,
b0_blockwise_copy,
b0_grid_buf,
b0_block_buf,
......@@ -1106,20 +1395,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
acc1_thread_buf.Clear();
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
// main body
if constexpr(num_gemm1_l_block_inner_loop > 1)
{
static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
// Data cast from FloatAcc0 to FloatA happen here
// Data cast from Acc0DataType to ADataType happens here
a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc0_thread_buf,
......@@ -1127,7 +1416,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
block_sync_lds();
......@@ -1135,10 +1424,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
});
}
// tail
......@@ -1177,9 +1466,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
FloatAcc1 c = c_thread_buf[I]; // O
FloatAcc1 c_new =
Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
Acc1DataType c = c_thread_buf[I]; // O
Acc1DataType c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM];
......@@ -1190,7 +1479,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
a_block_reset_copy_step); // rewind K
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc,
b0_block_reset_copy_step); // rewind K and step N
// update before next j iteration
......@@ -1220,7 +1509,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
......@@ -1268,8 +1557,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
......@@ -1307,8 +1596,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
......@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
......
......@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = WmmaK{} / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// B1
//
......@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = WmmaL{} / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// C
//
......
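A worked example of the coordinate split performed by the new MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1 transform above (the B0 variant works the same way over K and L): a flat (n, l) index of B1 is unmerged so that l feeds output dims (BLWmma, BLRow, BL1) and n feeds (N0*NRepeat, NWaves, NPerWmma). The snippet is an illustrative, self-contained C++ sketch; WmmaL = 16 comes from the diff, while BL1 = 8, NWaves = 4, and NPerWmma = 16 are assumed sizes, not values taken from the library.

#include <cstddef>
#include <cstdio>

// Illustrative only: mimic the unmerge order of the grid-descriptor transform,
//   l -> (BLWmma, BLRow, BL1),   n -> (N0*NRepeat, NWaves, NPerWmma),
// producing the 6-D coordinate consumed by the direct-to-VGPR copy.
struct B1Coord { std::size_t lwmma, nrep, nwave, lrow, nwmma, l1; };

constexpr B1Coord split_b1(std::size_t n, std::size_t l,
                           std::size_t WmmaL = 16, std::size_t BL1 = 8,
                           std::size_t NWaves = 4, std::size_t NPerWmma = 16)
{
    const std::size_t BLRow = WmmaL / BL1; // rows of the WMMA fragment along K/L
    return B1Coord{l / WmmaL,               // BLWmma
                   n / (NWaves * NPerWmma),  // N0 * NRepeat
                   (n / NPerWmma) % NWaves,  // NWaves
                   (l / BL1) % BLRow,        // BLRow
                   n % NPerWmma,             // NPerWmma
                   l % BL1};                 // BL1
}

int main()
{
    constexpr auto c = split_b1(/*n=*/70, /*l=*/37);
    std::printf("(%zu, %zu, %zu, %zu, %zu, %zu)\n",
                c.lwmma, c.nrep, c.nwave, c.lrow, c.nwmma, c.l1);
    // prints (2, 1, 0, 0, 6, 5): l=37 lands in WMMA tile 2, row 0, lane 5;
    //                            n=70 lands in repeat block 1, wave 0, lane 6
}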