Commit a0a469e4 authored by aska-0096

save progress

parent 3ddd3578
@@ -129,7 +129,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
         return make_tuple(c_thread_m, c_thread_n);
     }

-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
+    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
+
+    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
+        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+        Tuple4 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -299,8 +303,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                       B_K1,
                                       B_K1>;

-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
...
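The constructor change lets a caller inject the A/B thread origins while default arguments preserve the old zero-configuration behavior; the fused kernel in the gridwise file below uses this to pin gemm1's A origin (see the `make_tuple(0, 0, 0, 0, 0)` argument there), since gemm1's A operand already lives in this thread's registers. A minimal self-contained analogue of the pattern, with illustrative names rather than the commit's actual types:

```cpp
#include <tuple>

// Host-side sketch of the defaulted-origin constructor pattern.
struct BlockwiseGemmSketch
{
    using Origin = std::tuple<int, int, int, int, int>;

    // Stand-ins for CalculateAThreadOriginDataIndex / CalculateBThreadOriginDataIndex.
    static Origin CalculateAThreadOriginDataIndex() { return {0, 1, 2, 3, 4}; }
    static Origin CalculateBThreadOriginDataIndex() { return {4, 3, 2, 1, 0}; }

    // Defaults keep existing call sites source-compatible.
    BlockwiseGemmSketch(Origin a_origin = CalculateAThreadOriginDataIndex(),
                        Origin b_origin = CalculateBThreadOriginDataIndex())
        : a_origin_(a_origin), b_origin_(b_origin)
    {
    }

    Origin a_origin_;
    Origin b_origin_;
};

int main()
{
    BlockwiseGemmSketch gemm0;                      // old behavior, origins derived as before
    BlockwiseGemmSketch gemm1({0, 0, 0, 0, 0}, {}); // fused gemm1 pins the A origin to zero
}
```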
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -18,72 +18,134 @@
 namespace ck {
-template <typename FloatAB,
-          typename FloatGemmAcc,
-          typename FloatCShuffle,
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
           typename FloatC,
+          typename AGridDesc_K0_M_K1,
+          typename BGridDesc_K0_N_K1,
+          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename Block2CTileMap,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
+                         const FloatB* __restrict__ p_b0_grid,
+                         FloatC* __restrict__ p_c_grid,
+                         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+                         const BGridDesc_K0_N_K1 b0_grid_desc_k0_l_k1,
+                         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                             c_grid_desc_mblock_mperblock_nblock_nperblock,
+                         // const
+                         // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
+                         //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
+                         const AElementwiseOperation a_element_op,
+                         const BElementwiseOperation b_element_op,
+                         const CElementwiseOperation c_element_op,
+                         const Block2CTileMap block_2_ctile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b0_grid,
+                                                  p_c_grid,
+                                                  p_shared,
+                                                  a_grid_desc_k0_m_k1,
+                                                  b0_grid_desc_k0_l_k1,
+                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
+                                                  block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b0_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b0_grid_desc_k0_l_k1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx1100__))
+}
+// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
+// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
+template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB0,
+          typename FloatB1,
+          typename FloatAcc,
+          typename FloatCShuffle,
+          typename FloatC,
+          typename AElementwiseOperation,
+          typename B0ElementwiseOperation,
           typename AccElementwiseOperation,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
-          typename B1GridDesc_BK0_N_BK1,
+          typename AGridDesc_K0_M_K1,
+          typename B0GridDesc_K0_L_K1,
+          typename B1GridDesc_L0_N_L1,
           typename CGridDesc_M_N,
-          index_t NumGemmKPrefetchStage,
-          index_t BlockSize,
-          index_t MPerBlock,
-          index_t NPerBlock,
-          index_t KPerBlock,
+          index_t Gemm0MPerBlock,
+          index_t Gemm0LPerBlock,
+          index_t Gemm0K0PerBlock,
+          index_t Gemm0K1Value,
+          index_t Gemm0MPerWmma,
+          index_t Gemm0LPerWmma,
+          index_t Gemm0MRepeat,
+          index_t Gemm0LRepeat,
+          index_t Gemm1MPerBlock,
           index_t Gemm1NPerBlock,
-          index_t Gemm1KPerBlock,
-          index_t AK1Value,
-          index_t BK1Value,
-          index_t B1K1Value,
-          index_t MPerXdl,
-          index_t NPerXdl,
-          index_t MXdlPerWave,
-          index_t NXdlPerWave,
-          index_t Gemm1NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          index_t Gemm1L0PerBlock,
+          index_t Gemm1L1Value,
+          index_t Gemm1MPerWmma,
+          index_t Gemm1NPerWmma,
+          index_t Gemm1MRepeat,
+          index_t Gemm1NRepeat,
+          typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_AK1,
-          bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
-          index_t ABlockLdsExtraM,
-          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
-          typename BBlockTransferThreadClusterArrangeOrder,
-          typename BBlockTransferSrcAccessOrder,
-          index_t BBlockTransferSrcVectorDim,
-          index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_BK1,
-          bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
-          index_t BBlockLdsExtraN,
-          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+          index_t ABlockTransferDstScalarPerVector_K1,
+          bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool ABlockLdsExtraM,
+          typename B0BlockTransferThreadClusterLengths_K0_L_K1,
+          typename B0BlockTransferThreadClusterArrangeOrder,
+          typename B0BlockTransferSrcAccessOrder,
+          index_t B0BlockTransferSrcVectorDim,
+          index_t B0BlockTransferSrcScalarPerVector,
+          index_t B0BlockTransferDstScalarPerVector_K1,
+          bool B0ThreadTransferSrcResetCoordinateAfterRun,
+          bool B0BlockLdsExtraN,
+          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
           index_t B1BlockTransferSrcVectorDim,
           index_t B1BlockTransferSrcScalarPerVector,
-          index_t B1BlockTransferDstScalarPerVector_BK1,
+          index_t B1BlockTransferDstScalarPerVector_L1,
           bool B1ThreadTransferSrcResetCoordinateAfterRun,
-          index_t B1BlockLdsExtraN,
-          index_t CShuffleMXdlPerWavePerShuffle,
-          index_t CShuffleNXdlPerWavePerShuffle,
+          bool B1BlockLdsExtraN,
+          index_t CShuffleMRepeatPerShuffle,
+          index_t CShuffleNRepeatPerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched,
-          bool PadN,
-          bool MaskOutUpperTriangle,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
+          index_t NumGemmKPrefetchStage = 1,
+          LoopScheduler LoopSched       = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
+struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
 {
-    static_assert(LoopSched == LoopScheduler::Default,
-                  "Non-default loop scheduler is currently not supported");
-
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -94,161 +156,127 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     static constexpr auto I7 = Number<7>{};

     // K1 should be Number<...>
-    // Gemm0
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto AK1 = Number<AK1Value>{};
-    static constexpr auto BK1 = Number<BK1Value>{};
-
-    static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
-    static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
-
-    // Gemm1
-    static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
-    static constexpr auto B1K1 = Number<B1K1Value>{};
+    static constexpr auto K1 = Number<Gemm0K1Value>{};
+    static constexpr auto N1 = Number<Gemm1N1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
-    template <typename ABlockDesc_AK0_M_AK1>
-    __host__ __device__ static constexpr auto
-    MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
+    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
-        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
-
-        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
-            ABlockDesc_AK0_M_AK1{});
-    }
-
-    template <typename BBlockDesc_BK0_N_BK1>
-    __host__ __device__ static constexpr auto
-    MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
-    {
-        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
-
-        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
-            BBlockDesc_BK0_N_BK1{});
-    }
-
-    template <typename ABlockDesc_AK0_M_AK1>
-    __host__ __device__ static constexpr auto
-    MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
-    {
-        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
-    }
-
-    template <typename BBlockDesc_BK0_N_BK1>
-    __host__ __device__ static constexpr auto
-    MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
-    {
-        constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
-        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
-            BBlockDesc_BK0_N_BK1{});
+        constexpr auto max_lds_align = K1;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+            }
+        }();
+
+        return a_block_desc_k0perblock_mperblock_k1;
     }
-    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
     {
-        // A matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
-            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
-    {
+        constexpr auto max_lds_align = K1;
+
         // B matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
-            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
-    }
-
-    __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
-    {
-        // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
-            make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+            }
+        }();
+
+        return b_block_desc_k0perblock_nperblock_k1;
     }
+    // *Caution: here "repeat" means shuffle repeat
     __host__ __device__ static constexpr auto
-    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
+    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
     {
-        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-        constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
+        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
+        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);

-        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
             make_naive_tensor_descriptor_packed(
                 make_tuple(I1,
-                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
+                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
                            I1,
-                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
+                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));

-        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
+        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
     }
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
-                                         SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(FloatAB);
-        const index_t gemm1_bytes_end =
-            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(FloatAB);
-        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
-                                           SharedMemTrait::reduction_space_size_aligned) *
-                                          sizeof(FloatGemmAcc);
-        const index_t c_block_bytes_end =
-            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
-
-        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
+        // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
+            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
+            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+
+        constexpr auto max_lds_align = K1;
+
+        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
+
+        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
+
+        return (a_block_space_size_aligned * sizeof(FloatA) +
+                b_block_space_size_aligned * sizeof(FloatB));
    }
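A back-of-envelope check of the new LDS budget, with an illustrative tile shape and no LDS padding (these sizes are assumptions for the example, not values from the commit):

```cpp
#include <cstdio>

int main()
{
    // A tile: K0PerBlock x MPerBlock x K1; B tile: K0PerBlock x NPerBlock x K1, fp16.
    constexpr int K0PerBlock = 4, MPerBlock = 128, NPerBlock = 64, K1 = 8;
    constexpr int sizeof_fp16 = 2;

    constexpr int a_elems = K0PerBlock * MPerBlock * K1; // 4096, already K1-aligned
    constexpr int b_elems = K0PerBlock * NPerBlock * K1; // 2048
    constexpr int lds_bytes = (a_elems + b_elems) * sizeof_fp16;

    std::printf("%d bytes of LDS\n", lds_bytes); // 12288, well inside a typical 64 KiB workgroup limit
    return 0;
}
```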
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                  const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                  const BGridDesc_K0_N_K1& b0_grid_desc_k0_l_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
                   const Block2CTileMap& block_2_ctile_map)
     {
-        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
-                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
+        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
+                      "wrong! K1 need to be known at compile-time");
+
+        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
+                          (NPerBlock % (NRepeat * NPerWmma)) == 0,
                       "Invalid tuning param!");

-        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
-        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
-        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-        const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
+        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N  = b0_grid_desc_k0_l_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

-        if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
-        {
+        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+             K0 == b0_grid_desc_k0_l_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
+             K1 == b0_grid_desc_k0_l_k1.GetLength(I2)))
             return false;
-        }

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
-             Gemm1N % Gemm1NPerBlock == 0))
-        {
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
-        }

-        // check gemm0 gridwise gemm pipeline
-        const auto num_gemm0_k_loop = K / KPerBlock;
-        if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
-        {
-            return false;
-        }
-
-        // check gemm1 gridwise gemm pipeline
-        if(!(NPerBlock % Gemm1KPerBlock == 0))
-        {
-            return false;
-        }
-
-        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
-        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
@@ -264,7 +292,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const index_t num_loop = K / KPerBlock;
+        const index_t num_loop = K / (K0PerBlock * K1);

         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
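A worked instance of the revised loop count, with illustrative sizes: each pipeline iteration now consumes `K0PerBlock * K1` elements of K rather than a monolithic `KPerBlock`.

```cpp
constexpr int K = 4096, K0PerBlock = 4, K1 = 8;  // example values only
constexpr int num_loop = K / (K0PerBlock * K1);  // 4096 / 32 = 128 iterations
static_assert(num_loop == 128, "the main K loop runs while num_loop > 1");
```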
@@ -276,12 +304,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto N = c_grid_desc_m_n.GetLength(I1);

         const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / Gemm1NPerBlock;
+        const auto NBlock = N / NPerBlock;

         const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
@@ -289,284 +317,225 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     }
     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
     {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
             c_grid_desc_m_n);
     }

     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
     using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
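A hypothetical host-side flow around these helpers (the device op that wraps this gridwise struct is not part of the commit, so names and structure here are assumptions): validate the descriptors, build the default tile map, and size the launch grid from it.

```cpp
#include <stdexcept>

// Sketch: GridwiseGemm is this gridwise struct; the descriptor arguments come
// from the device op's argument construction.
template <typename GridwiseGemm, typename ADesc, typename B0Desc, typename CDesc>
auto PrepareLaunch(const ADesc& a_grid_desc_k0_m_k1,
                   const B0Desc& b0_grid_desc_k0_l_k1,
                   const CDesc& c_grid_desc_m_n)
{
    const auto block_2_ctile_map =
        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);

    if(!GridwiseGemm::CheckValidity(
           a_grid_desc_k0_m_k1, b0_grid_desc_k0_l_k1, c_grid_desc_m_n, block_2_ctile_map))
    {
        throw std::runtime_error("wrong! unsupported problem/tuning combination");
    }

    // BlockToCTileMap_M00_N0_M01Adapt exposes CalculateGridSize in CK.
    return block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
}
```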
-    struct SharedMemTrait
-    {
-        // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto b1_block_desc_bk0_n_bk1 =
-            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
-        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
-
-        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-
-        static constexpr auto a_block_space_offset  = 0;
-        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
-        static constexpr auto b1_block_space_offset = 0;
-
-        // LDS allocation for reduction
-        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
-
-        static constexpr auto reduction_space_offset = 0;
-
-        // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-    };
-
-    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
-                               const FloatAB* __restrict__ p_b1_grid,
+    template <bool HasMainKBlockLoop, typename C0MatrixMask, typename Block2CTileMap = DefaultBlock2CTileMap>
+    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
+                               const FloatB0* __restrict__ p_b0_grid,
+                               const FloatB1* __restrict__ p_b1_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
+                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                               const B0GridDesc_K0_L_K1& b0_grid_desc_k0_l_k1,
+                               const B1GridDesc_L0_N_L1& b1_grid_desc_l0_n_l1,
+                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const AElementwiseOperation& a_element_op,
-                               const BElementwiseOperation& b_element_op,
+                               const B0ElementwiseOperation& b0_element_op,
                                const AccElementwiseOperation& acc_element_op,
                                const B1ElementwiseOperation& b1_element_op,
                                const CElementwiseOperation& c_element_op,
-                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
-                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map,
-                               const C0MatrixMask& c0_matrix_mask)
+                               const C0MatrixMask& c0_matrix_mask,
+                               const Block2CTileMap& block_2_ctile_map)
     {
+        // clang-format off
+/*******************************************************************************/
+// Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+            p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
-        const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+/*******************************************************************************/
+// BlockIdx.x -> [BlockId.m, BlockId.n]
+        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

         if(!block_2_ctile_map.ValidCTileIndex(
                block_work_idx,
                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
-        {
-            return;
-        }
+        { return; }

-        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
-        const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
-        const index_t gemm1_n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
+// Store BlockId into SGPR
+        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
-        //
-        // set up Gemm0
-        //
+/*******************************************************************************/
+// set up Gemm0
+/*******************************************************************************/
+// Block level, A/B matrix thread mapping in LDS, as destination of blockwise copy
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        constexpr auto max_lds_align = K1;
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                AElementwiseOperation,
-                                                tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<AK0, MPerBlock, AK1>,
-                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
-                                                ABlockTransferThreadClusterArrangeOrder,
-                                                FloatAB,
-                                                FloatAB,
-                                                decltype(a_grid_desc_ak0_m_ak1),
-                                                decltype(a_block_desc_ak0_m_ak1),
-                                                ABlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                ABlockTransferSrcVectorDim,
-                                                2,
-                                                ABlockTransferSrcScalarPerVector,
-                                                ABlockTransferDstScalarPerVector_AK1,
-                                                1,
-                                                1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
-                a_grid_desc_ak0_m_ak1,
+            ThreadGroupTensorSliceTransfer_v4r1<
+                ThisThreadBlock,
+/* typename SrcElementwiseOperation,              */ AElementwiseOperation,
+/* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
+/* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
+/* typename BlockSliceLengths,                    */ Sequence<K0PerBlock, MPerBlock, K1>,
+/* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
+/* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
+/* typename SrcData,                              */ FloatA,
+/* typename DstData,                              */ FloatA,
+/* typename SrcDesc,                              */ decltype(a_grid_desc_k0_m_k1),
+/* typename DstDesc,                              */ decltype(a_block_desc_k0perblock_mperblock_k1),
+/* typename SrcDimAccessOrder,                    */ ABlockTransferSrcAccessOrder,
+/* typename DstDimAccessOrder,                    */ Sequence<0, 1, 2>,
+/* index_t SrcVectorDim,                          */ ABlockTransferSrcVectorDim,
+/* index_t DstVectorDim,                          */ 2,
+/* index_t SrcScalarPerVector,                    */ ABlockTransferSrcScalarPerVector,
+/* index_t DstScalarPerVector,                    */ ABlockTransferDstScalarPerVector_K1,
+/* index_t SrcScalarStrideInVector,               */ 1,
+/* index_t DstScalarStrideInVector,               */ 1,
+/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
+/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
+                a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
-                a_block_desc_ak0_m_ak1,
+                a_block_desc_k0perblock_mperblock_k1,
                 make_multi_index(0, 0, 0),
-                tensor_operation::element_wise::PassThrough{});
+                ck::tensor_operation::element_wise::PassThrough{});
         // B matrix blockwise copy
         auto b_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
-                                                tensor_operation::element_wise::PassThrough,
+                                                ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<BK0, NPerBlock, BK1>,
-                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
+                                                Sequence<K0PerBlock, NPerBlock, K1>,
+                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 BBlockTransferThreadClusterArrangeOrder,
-                                                FloatAB,
-                                                FloatAB,
-                                                decltype(b_grid_desc_bk0_n_bk1),
-                                                decltype(b_block_desc_bk0_n_bk1),
+                                                FloatB,
+                                                FloatB,
+                                                decltype(b0_grid_desc_k0_l_k1),
+                                                decltype(b_block_desc_k0perblock_nperblock_k1),
                                                 BBlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
+                                                Sequence<0, 1, 2>,
                                                 BBlockTransferSrcVectorDim,
                                                 2,
                                                 BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_BK1,
+                                                BBlockTransferDstScalarPerVector_K1,
                                                 1,
                                                 1,
-                                                true, // SrcResetCoord
-                                                true, // DstResetCoord
-                                                NumGemmKPrefetchStage>(
-                b_grid_desc_bk0_n_bk1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                                                BThreadTransferSrcResetCoordinateAfterRun,
+                                                true>(
+                b0_grid_desc_k0_l_k1,
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
-                b_block_desc_bk0_n_bk1,
+                b_block_desc_k0perblock_nperblock_k1,
                 make_multi_index(0, 0, 0),
-                tensor_operation::element_wise::PassThrough{});
+                ck::tensor_operation::element_wise::PassThrough{});
-        // Fused Gemm+Gemm pipeline
-        // for n in N0:
-        //   for k in K0:
-        //     acc[m][n] += A[m][k] * B0[k][n]
-        //   acc1[m][o] += acc[m][n] * B1[n][o]
-
-        // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-
-        auto blockwise_gemm = BlockwiseGemmXdlops_v2<
-            BlockSize,
-            FloatAB,
-            FloatGemmAcc,
-            decltype(a_block_desc_ak0_m_ak1),
-            decltype(b_block_desc_bk0_n_bk1),
-            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
-            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
-            MPerBlock,
-            NPerBlock,
-            KPerBlock,
-            MPerXdl,
-            NPerXdl,
-            MXdlPerWave,
-            NXdlPerWave,
-            KPack,
-            true>{}; // TransposeC
+/*******************************************************************************/
+// Gemm0
+        constexpr auto WmmaK = 16;
+        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
+
+        auto blockwise_gemm =
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+                                                                       FloatA,
+                                                                       FloatB,
+                                                                       FloatAcc,
+                                                                       decltype(a_block_desc_k0perblock_mperblock_k1),
+                                                                       decltype(b_block_desc_k0perblock_nperblock_k1),
+                                                                       MPerWmma,
+                                                                       NPerWmma,
+                                                                       MRepeat,
+                                                                       NRepeat,
+                                                                       KPack>{};
+
+        // Prepare registers for the A*B0 matrix
         auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
+/*******************************************************************************/
+        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
         // LDS allocation for A and B: be careful of alignment
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
-        const auto a_block_reset_copy_step =
-            make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
-        const auto b_block_reset_copy_step =
-            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
-
-        // gridwise GEMM pipeline
-        // Only supports LoopScheduler::Default
-        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
-                                                                          NumGemmKPrefetchStage,
-                                                                          LoopScheduler::Default>();
-
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
-            KPerBlock);
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+
+        // Shift per SUB_K
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+        const auto b_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
+
+/*******************************************************************************/
+// softmax
+/*******************************************************************************/
+        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAcc*>(p_shared), math::integer_least_multiple(BlockSize, max_lds_align));
+
+        // get acc0 8D thread cluster
+        constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
+            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
+            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+        constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
+        constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
+        constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
+        constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
+        constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
+        constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
+        constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
+        constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
-        //
-        // set up Gemm1
-        //
-
-        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
-        constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
-        constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-        constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-        constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-        constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-        constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-        constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-        constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-        constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
-
-        constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
-
-        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
-        // n0_n1_n2_n3 -> k0
-        // m0_m1_m2 -> m
-        // n4 -> k1
-        // NOTE: had to use merge_v3 or will spit out compilation errors
-        constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
-            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
-                       make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
-                       make_pass_through_transform(n4)),
-            make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        // A1 matrix in AccVGPR
-        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
-        constexpr auto AccN3 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
-
-        constexpr auto A1ThreadSlice_K0_M_K1 =
-            make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
-
-        constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
-        constexpr auto A1ThreadSliceM  = A1ThreadSlice_K0_M_K1[I1];
-        constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
-        constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
-            A1ThreadSlice_K0_M_K1,
-            make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
+        // get acc0 thread map
+        constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
+                       make_pass_through_transform(I1)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(
+                make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+        const auto threadid_to_m_n_thread_cluster_adaptor =
+            chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
+
+        // get acc0 2D thread cluster & 2D thread slice
+        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
+        constexpr auto thread_slice_desc_m_n =
+            make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
+
+        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
+                                                  FloatAcc,
+                                                  decltype(threadid_to_m_n_thread_cluster_adaptor),
+                                                  decltype(thread_cluster_desc_m_n),
+                                                  decltype(thread_slice_desc_m_n)>{};
+
+        // Initialize running sum and max of exponentiating row vectors
+        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
+        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
+        running_sum     = 0;
+        running_sum_new = 0;
+        running_max     = NumericLimits<FloatAcc>::Lowest();
+        running_max_new = NumericLimits<FloatAcc>::Lowest();
+/*******************************************************************************/
+// set up Gemm1
+/*******************************************************************************/
         // B1 matrix in LDS memory, dst of blockwise copy
-        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1();

         // A1 matrix blockwise copy
         auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-            FloatGemmAcc,
-            FloatAB,
+            FloatAcc,
+            FloatA,
             decltype(acc_thread_desc_k0_m_k1),
             decltype(a1_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -574,7 +543,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             Sequence<1, 0, 2>,
             2,
             n4>{tensor_operation::element_wise::PassThrough{}};

         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -605,142 +574,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 b1_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});
-        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            a1_thread_desc_k0_m_k1.GetElementSpaceSize());
-
-        // reuse LDS space for gemm0's b_block_buf
-        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
-        // selected_mfma.k_per_blk <= Gemm1KPack
-        //
-        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
-        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
-        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
-        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
-        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
-        // therefore we may just as well assign Gemm1KPack = group_size
-        constexpr index_t Gemm1KPack =
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
-
-        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
-            BlockSize,
-            FloatAB,
-            FloatGemmAcc,
-            decltype(a1_thread_desc_k0_m_k1),
-            decltype(b1_block_desc_bk0_n_bk1),
-            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
-            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
-            MPerBlock,
-            Gemm1NPerBlock,
-            Gemm1KPerBlock,
-            MPerXdl,
-            NPerXdl,
-            MXdlPerWave,
-            Gemm1NXdlPerWave,
-            Gemm1KPack,
-            true,       // TransposeC
-            Gemm1KPack, // AMmaKStride
-            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
-            // BMmaKStride
-            make_tuple(0, 0, 0, 0)}; // A_origin
-
-        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
+        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(a1_thread_desc_k0_m_k1.GetElementSpaceSize());
+        auto b1_block_buf  = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared), b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+
+        auto blockwise_gemm1 =
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
+                                                                  FloatA,
+                                                                  FloatB,
+                                                                  FloatAcc,
+                                                                  decltype(a1_thread_desc_k0perblock_mperblock_k1),
+                                                                  decltype(b1_block_desc_k0perblock_nperblock_k1),
+                                                                  MPerWmma,
+                                                                  NPerWmma,
+                                                                  MRepeat,
+                                                                  NRepeat,
+                                                                  KPack>{make_tuple(0, 0, 0, 0, 0)};
+
+        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
+
+        const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
-        //
-        // Blockwise softmax
-        //
-        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
-            SharedMemTrait::reduction_space_size_aligned);
-
-        // get acc0 8D thread cluster
-        constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
-        constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
-        constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
-        constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
-        constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
-        constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
-        constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
-        constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
-        constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
-
-        // get acc0 thread map
-        constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
-                       make_pass_through_transform(I1)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(
-                make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
-        const auto threadid_to_m_n_thread_cluster_adaptor =
-            chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
-
-        // get acc0 2D thread cluster & 2D thread slice
-        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
-            make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
-        constexpr auto thread_slice_desc_m_n =
-            make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
-
-        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
-                                                  FloatGemmAcc,
-                                                  decltype(threadid_to_m_n_thread_cluster_adaptor),
-                                                  decltype(thread_cluster_desc_m_n),
-                                                  decltype(thread_slice_desc_m_n)>{};
-
-        const index_t num_gemm1_k_block_outer_loop =
-            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

         // Initialize C
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
-            c_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc1_thread_buf.Size(), true> c_thread_buf;
         c_thread_buf.Clear();

-        // Initialize running sum and max of exponentiating row vectors
-        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
-        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
-        running_sum     = 0;
-        running_sum_new = 0;
-        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
-        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
+/*******************************************************************************/
+// Flash Attention
+// Dao, Tri, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." arXiv preprint arXiv:2205.14135 (2022).
-        // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
-        do
-        {
-            auto n_block_data_idx_on_grid =
-                __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
-            // gemm0
-            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
-                                                                   a_block_desc_ak0_m_ak1,
-                                                                   a_blockwise_copy,
-                                                                   a_grid_buf,
-                                                                   a_block_buf,
-                                                                   a_block_slice_copy_step,
-                                                                   b_grid_desc_bk0_n_bk1,
-                                                                   b_block_desc_bk0_n_bk1,
-                                                                   b_blockwise_copy,
-                                                                   b_grid_buf,
-                                                                   b_block_buf,
-                                                                   b_block_slice_copy_step,
-                                                                   blockwise_gemm,
-                                                                   acc_thread_buf,
-                                                                   num_k_block_main_loop);
+// Outer loop, along GEMM_L
+// Inner loop, along GEMM_K
+        do
+        {
+            // gemm0 start, A-B swapped
+            const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+            GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                              a_block_desc_k0perblock_mperblock_k1,
+                                                              a_blockwise_copy,
+                                                              a_grid_buf,
+                                                              a_block_buf,
+                                                              a_block_slice_copy_step,
+                                                              b0_grid_desc_k0_l_k1,
+                                                              b_block_desc_k0perblock_nperblock_k1,
+                                                              b_blockwise_copy,
+                                                              b0_grid_buf,
+                                                              b_block_buf,
+                                                              b_block_slice_copy_step,
+                                                              blockwise_gemm,
+                                                              acc_thread_buf,
+                                                              K0BlockMainLoop);
             // do MNK padding or upper triangular masking
             if constexpr(MaskOutUpperTriangle || PadN)
             {
@@ -797,13 +680,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 });
             }
             else
             {
                 static_for<0, acc_thread_buf.Size(), 1>{}(
                     [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
             }

-            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
+            block_sync_lds();
+            // gemm0 end
+
+            // Tiled softmax start
             // softmax
             SoftmaxBuf& max = blockwise_softmax.max_value_buf;
             SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
@@ -814,7 +699,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_max_new = mathext::max(max, running_max);
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                               mathext::exp(max - running_max_new) * sum;

+            // Intra-row data permutation, make swizzled A input for WMMA
+            __builtin_amdgcn_permlane16(0xeca86420, 0xfdb97531);
+            // Low/high rows move data to the low/high half of the thread buffer
+            /* thread copy */
+            // Inter-row data permutation, fulfill the data duplication requirement
+            __builtin_amdgcn_permlanex16(0x76543210, 0xfedcba98);
+
             // gemm1
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
@@ -841,6 +732,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 if constexpr(num_gemm1_k_block_inner_loop > 1)
                 {
                     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
+                        // Data cast from FloatAcc to FloatA happens here
                         a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                               make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
                                               acc_thread_buf,
@@ -879,14 +771,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             } // end gemm1
-            // workaround compiler issue; see ck/ck.hpp
-            if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-                         is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
-                         Gemm1NPerBlock == 128)
-            {
-                __builtin_amdgcn_sched_barrier(0);
-            }
-
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
             constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
@@ -910,8 +794,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 FloatGemmAcc c_new =
                     (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                      math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                    running_sum_new[iM]; // Formula by Dao et al.,
-                                         // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
+                    running_sum_new[iM];

                 c_thread_buf(I) = c_new; // O_new
             });
@@ -927,120 +810,102 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_sum = running_sum_new;

             block_sync_lds(); // wait for gemm1 LDS read
-        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
+        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop);
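The loop above is the online-softmax recurrence of FlashAttention (Dao et al., arXiv:2205.14135, section 3.1). A scalar C++ model of one fold step, matching the `running_max_new` / `running_sum_new` / `c_new` formulas in the kernel (per-row scalars here stand in for the `SoftmaxBuf` register vectors; illustrative code, not part of the commit):

```cpp
#include <algorithm>
#include <cmath>

struct RowState
{
    float running_max = -INFINITY;
    float running_sum = 0.f;
    float o           = 0.f; // running output accumulator (the "O_new" value)
};

// Fold one tile's (tile_max, tile_sum, acc1) into the running state.
inline void fold_tile(RowState& s, float tile_max, float tile_sum, float acc1)
{
    const float new_max = std::max(tile_max, s.running_max);
    const float new_sum = std::exp(s.running_max - new_max) * s.running_sum +
                          std::exp(tile_max - new_max) * tile_sum;
    s.o = (s.running_sum * std::exp(s.running_max - new_max) * s.o +
           std::exp(tile_max - new_max) * acc1) /
          new_sum;
    s.running_max = new_max;
    s.running_sum = new_sum;
}
```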
+/*******************************************************************************/
-        // shuffle C and write out
+        // write out to C, implement shuffle
         {
-            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
-                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
-                          "wrong!");
-
-            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-            constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
-
-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
-            // TODO: hacky, fix it!
-            // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
-                gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
-            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
-            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
-            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
-            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
-            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
-            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
-            constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
-            constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
-
-            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
+                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+
+            // This API provides every dimension (size) needed below
+            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
+                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+
+            constexpr auto MWave              = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
+            constexpr auto MSubGroup          = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
+            constexpr auto NWave              = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
+            constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
+            constexpr auto MAccVgprs          = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
+
+            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
+            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
+                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                 static_cast<FloatCShuffle*>(p_shared),
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
+                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                 make_tuple(
                     make_freeze_transform(I0),
                     make_unmerge_transform(make_tuple(
-                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
-                        M1,                                      // M1 = MWave
-                        M2)),                                    // M2 = MPerXdl
+                        Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
+                        MWave,                               // MWave
+                        MSubGroup,                           // MSubGroup * MAccVgprs = MPerWmma
+                        MAccVgprs)),
                     make_freeze_transform(I0),
                     make_unmerge_transform(make_tuple(
-                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
-                        N1,                                      // N1 = NWave
-                        N2,                                      // N2 * N3 * N4 = NPerXdl
-                        N3,
-                        N4))),
+                        Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
+                        NWave,                               // NWave
+                        NThreadPerSubGroup))),               // NThreadPerSubGroup = NPerWmma
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(
-                    Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
+                make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx = const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( make_single_stage_tensor_adaptor(
make_multi_index(m_thread_data_on_block)); make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = make_tuple(Sequence<0>{}));
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
make_tuple(Sequence<0, 1, 2, 3, 4>{}), make_multi_index(m_thread_data_on_block));
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
const auto n_thread_data_on_block_idx = make_multi_index(n_thread_data_on_block));
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS // shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc, ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatCShuffle, FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle, Sequence<CShuffleMRepeatPerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1, I1,
I1, I1,
CShuffleNRepeatPerShuffle,
I1, I1,
N2,
I1, I1,
N4>, MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6>,
7, 6,
1, 1, // vector write pixel
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>{ true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
make_multi_index(0, make_multi_index(0,
0,
m_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2], m_thread_data_on_block_idx[I2],
0,
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2], n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3], m_thread_data_on_block_idx[I3]),
n_thread_data_on_block_idx[I4]), ck::tensor_operation::element_wise::PassThrough{}};
tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global // shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
...@@ -1048,47 +913,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1048,47 +913,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun, true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun> false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0), make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op}; c_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>, SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMXdlPerWavePerShuffle, Sequence<CShuffleMRepeatPerShuffle,
CShuffleNXdlPerWavePerShuffle,
1, 1,
1, 1,
CShuffleNRepeatPerShuffle,
1, 1,
N2,
1, 1,
N4>>{}; MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem // space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global = constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, Gemm1NPerBlock>, SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>, Sequence<0, 2, 1, 3>,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
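For intuition, the two space-filling curves above simply enumerate the MRepeat x NRepeat accumulator tiles in CShuffle*PerShuffle-sized chunks, one chunk per access. A standalone model of that enumeration (hypothetical sizes, plain row-major order; CK's SpaceFillingCurve can also traverse in snake order, so this is a simplification):

#include <array>
#include <cstdio>

// Assumed tuning parameters, not taken from the patch.
constexpr int MRepeat = 4, NRepeat = 4;   // accumulator tiles per thread
constexpr int MShuffle = 2, NShuffle = 2; // tiles written per shuffle pass

// analogue of sfc_c_vgpr.GetNumOfAccess()
constexpr int num_access = (MRepeat / MShuffle) * (NRepeat / NShuffle);

// analogue of sfc_c_vgpr.GetIndexTupleOfNumber(access_id): origin of the chunk
constexpr std::array<int, 2> chunk_origin(int access_id)
{
    return {(access_id / (NRepeat / NShuffle)) * MShuffle,
            (access_id % (NRepeat / NShuffle)) * NShuffle};
}

int main()
{
    for(int id = 0; id < num_access; ++id)
        std::printf("access %d -> tile origin (m0 = %d, n0 = %d)\n",
                    id, chunk_origin(id)[0], chunk_origin(id)[1]);
    return 0;
}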
@@ -1099,10 +964,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                block_sync_lds();

                // each thread write its data from VGPR to LDS
-               c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+               c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
-                                             c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                                             c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
@@ -1110,7 +975,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                // each block copy its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
-                   c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                   c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);
@@ -1118,13 +983,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
+       // clang-format on
    }
};
...
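The single-stage tensor adaptors in the hunk above are mixed-radix decompositions of a flat lane offset into per-dimension coordinates. A minimal sketch of the same arithmetic (assumed extents; the real values come from the block descriptor lengths queried via GetLength(I1), GetLength(I2), and so on):

#include <array>
#include <cstdio>

// Hypothetical extents for one wave32 WMMA layout; not taken from the patch.
constexpr int MWaveDim = 2, MSubGroupDim = 2, MAccVgprsDim = 8;

// analogue of m_thread_data_on_block_to_..._adaptor.CalculateBottomIndex():
// invert merge(mrepeat, mwave, msubgroup, maccvgprs) applied to a flat m offset
constexpr std::array<int, 4> unmerge_m(int m)
{
    const int maccvgprs = m % MAccVgprsDim;
    m /= MAccVgprsDim;
    const int msubgroup = m % MSubGroupDim;
    m /= MSubGroupDim;
    const int mwave = m % MWaveDim;
    m /= MWaveDim;
    return {m, mwave, msubgroup, maccvgprs}; // what remains is mrepeat
}

int main()
{
    const auto idx = unmerge_m(37); // prints (1, 0, 0, 5)
    std::printf("m = 37 -> (mrepeat, mwave, msubgroup, maccvgprs) = (%d, %d, %d, %d)\n",
                idx[0], idx[1], idx[2], idx[3]);
    return 0;
}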
@@ -414,7 +414,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        auto blockwise_gemm =
-           BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+           BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                FloatA,
                FloatB,
                FloatAcc,
...
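In the hunk above, KPack is derived with math::integer_least_multiple(K1, WmmaK), i.e. K1 rounded up to the next multiple of the WMMA K extent. A quick standalone check of that arithmetic (assuming the helper's usual ceil-then-multiply semantics):

#include <cstdio>

// Standalone model of ck::math::integer_least_multiple(x, y): the smallest
// multiple of y that is >= x, for positive integers.
constexpr int integer_least_multiple(int x, int y) { return ((x + y - 1) / y) * y; }

int main()
{
    constexpr int WmmaK = 16; // K extent of one 16x16x16 WMMA instruction
    const int k1_samples[] = {2, 8, 16, 24};
    for(int k1 : k1_samples)
        std::printf("K1 = %2d -> KPack = %2d\n", k1, integer_least_multiple(k1, WmmaK));
    return 0;
}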
@@ -23,11 +23,11 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
    {
        // * Inline assembly needs to eliminate the duplicated data load; the
        // compiler won't help you delete it.
-       amd_assembly_wmma_f32_16x16x16_f16_w32(
-           reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-       // reg_c.template AsType<float8_t>()(Number<0>{}) =
-       //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-       //         reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+       // amd_assembly_wmma_f32_16x16x16_f16_w32(
+       //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+       reg_c.template AsType<float8_t>()(Number<0>{}) =
+           __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+               reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};
...
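The builtin this hunk switches back to maps one wavefront onto a 16x16x16 tile: in wave32 mode each lane carries 16 half values of A, 16 of B, and 8 float accumulators (matching MAccVgprs = 8 above). A minimal sketch of a wrapper around it, assuming a gfx11 (RDNA3) target compiled with hipcc and Clang's vector extensions:

#include <hip/hip_runtime.h>

// Clang ext_vector types matching the builtin's operand shapes.
typedef _Float16 half16_t __attribute__((ext_vector_type(16)));
typedef float float8_t __attribute__((ext_vector_type(8)));

// One wavefront computes c += a * b on a 16x16x16 tile in wave32 mode.
__device__ void wmma_16x16x16_f16_acc_f32(const half16_t& a,
                                          const half16_t& b,
                                          float8_t& c)
{
    c = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, c);
}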