"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "8f5b9910c122deea78b6d6b25169bb44d2e21619"
Commit a045e0be authored by aska-0096's avatar aska-0096
Browse files

Example branch provided to the compiler team

parent fbc576b5
@@ -240,7 +240,6 @@ include_directories(BEFORE
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
32, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
@@ -5,9 +5,6 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100")
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -17,8 +14,3 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
if(GPU_TARGETS MATCHES "gfx1100")
add_custom_target(example_gemm_scale_softmax_gemm_wmma)
add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
  Gemm0:          A_g_m_k * B0_g_k_l
  Gemm1: Softmax(Gemm0 result) * B1_g_l_n
*/
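// Concretely (sizes here are illustrative; the actual ones come from the run_* include
// below): with G batches, A is G x M x K, B0 is G x K x L and B1 is G x L x N. Gemm0
// forms a G x M x L accumulator, the Scale element-wise op multiplies it by the
// runner-supplied alpha (commonly 1/sqrt(K) in attention-style workloads), softmax
// normalizes each length-L row, and Gemm1 contracts that row with B1 into the
// G x M x N output C.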
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
Acc0BiasDataType,
Acc0DataType,
Acc1BiasDataType,
Acc1DataType,
CShuffleDataType,
CDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
256,
128, // MPerBlock
128, // LPerBlock
4, // K0PerBlock
8, // K1
64, // NPerBlock
4, // L0PerBlock
8, // L1
16, // MPerWMMA
16, // LPerWMMA
16, // NPerWMMA
// Per repeat = wave_m = wave_num, wave_n = 1
1, // MRepeat
8, // LRepeat
4, // NRepeat
S<4, 64, 1>, // ABlockTransfer MK -> K0 M K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
S<0, 2, 1>,
S<0, 2, 1>,
1,
8,
1,
false,
1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
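// Rough resource picture for this instance (assuming wave32 on gfx11): BlockSize 256 = 8 waves;
// MWaves = MPerBlock / (MRepeat * MPerWMMA) = 128 / (1 * 16) = 8 and
// LWaves = LPerBlock / (LRepeat * LPerWMMA) = 128 / (8 * 16) = 1, which matches the
// "wave_m = wave_num, wave_n = 1" note above. The Gemm0 K tile is K0PerBlock * K1 = 4 * 8 = 32.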
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
Acc0DataType,
Acc1DataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
Acc1DataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
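For readers who want the fused math without the CK reference operators, the computation verified above can be written as a plain host loop. The sketch below is illustrative only: it assumes a single batch, row-major storage and float accumulation, and the function name and layout choices are ours rather than part of CK.

#include <algorithm>
#include <cmath>
#include <vector>

// C[M x N] = softmax_rowwise(alpha * A[M x K] * B0[K x L]) * B1[L x N]
void naive_gemm_softmax_gemm(const std::vector<float>& a,  // M x K, row-major
                             const std::vector<float>& b0, // K x L, row-major
                             const std::vector<float>& b1, // L x N, row-major
                             std::vector<float>& c,        // M x N, row-major
                             int M, int K, int L, int N, float alpha)
{
    std::vector<float> s(L); // one scaled Gemm0 row at a time
    for(int m = 0; m < M; ++m)
    {
        // Gemm0 row, then the Scale element-wise op
        for(int l = 0; l < L; ++l)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b0[k * L + l];
            s[l] = alpha * acc;
        }
        // numerically stable row-wise softmax over L
        const float row_max = *std::max_element(s.begin(), s.end());
        float row_sum = 0.f;
        for(int l = 0; l < L; ++l)
        {
            s[l] = std::exp(s[l] - row_max);
            row_sum += s[l];
        }
        for(int l = 0; l < L; ++l)
            s[l] /= row_sum;
        // Gemm1 row
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int l = 0; l < L; ++l)
                acc += s[l] * b1[l * N + n];
            c[m * N + n] = acc;
        }
    }
}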
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Computes C = A * B0 * B1
// MN = MK * KL * LN
//      MK * KL      -> Acc0 (Gemm0)
//      Acc0 * LN    -> Acc1 (Gemm1)
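// Dimension naming used throughout this header: M and K are Gemm0's output-row and
// contraction dims, L is Gemm0's output-column dim and simultaneously Gemm1's
// contraction dim, N is Gemm1's output-column dim, and G enumerates batches (hence
// NumDimGemm0N == NumDimL and NumDimGemm1K == NumDimL below). In an attention-style
// use, M and L map to the query and key sequence lengths and K and N to the head
// dimensions; that mapping is an interpretation, not something this header requires.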
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimL,
index_t NumDimK,
index_t NumDimN,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename Acc0BiasDataType,
typename Acc0DataType,
typename Acc1BiasDataType,
typename Acc1DataType,
typename CShuffleDataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization B0Spec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock,
ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
ck::index_t K1, //
ck::index_t NPerBlock,
ck::index_t L0PerBlock,
ck::index_t L1,
ck::index_t MPerWMMA,
ck::index_t LPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MRepeat,
ck::index_t LRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferDstScalarPerVector_K1,
bool B0BlockLdsAddExtraL,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferDstScalarPerVector_L1,
bool B1BlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimL,
NumDimK,
NumDimN,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimL;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimL;
static constexpr index_t KPerBlock = K0PerBlock * K1;
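// For reference, the fp16 wmma example in this commit instantiates K0PerBlock = 4 and
// K1 = 8, i.e. KPerBlock = 32.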
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
ASpec,
B0Spec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<K1>{});
}
static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
Number<K1>{});
}
static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
Number<L1>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return MaskOutUpperTrianglePredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
};
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle<
// DataType Family
ADataType,
B0DataType,
Acc0DataType,
B1DataType,
Acc1DataType,
CShuffleDataType,
CDataType,
// ElementwiseOp Family
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1,
B0GridDesc_BK0_L_BK1,
B1GridDesc_BL0_N_BL1,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
K1, //
NPerBlock,
L0PerBlock,
L1,
MPerWMMA,
LPerWMMA,
NPerWMMA,
MRepeat,
LRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
true,
ABlockLdsAddExtraM,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_K1,
true,
B0BlockLdsAddExtraL,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_L1,
false,
B1BlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgument
{
Argument(
const ADataType* p_a_grid,
const B0DataType* p_b0_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
const index_t M01,
const index_t N01,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b0_grid_{p_b0_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{
Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_g_l_k_{
Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_g_n_l_{
Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
a_element_op_{a_element_op},
b0_element_op_{b0_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)},
raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_ptr_offset_of_batch_{
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ls_lengths;
ignore = acc0_biases_gs_ms_ls_strides;
ignore = acc1_biases_gs_ms_ns_lengths;
ignore = acc1_biases_gs_ms_ns_strides;
if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
b0_grid_desc_bk0_l_bk1_,
b1_grid_desc_bl0_n_bl1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
// Pointers
const ADataType* p_a_grid_;
const B0DataType* p_b0_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
// Tensor Descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
// Block to Tile mapping
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
// ElementwiseOp
AElementwiseOperation a_element_op_;
B0ElementwiseOperation b0_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// Strides for the last M/L/K/N dimensions of A/B0/B1/C
// for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_lz_kz_nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b0_lz_kz_strides_;
std::vector<index_t> b1_nz_lz_strides_;
std::vector<index_t> c_mz_nz_strides_;
index_t batch_count_;
// Batch Offset
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
GridwiseOp,
ADataType,
B0DataType,
B1DataType,
CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::B0GridDesc_BK0_L_BK1,
DeviceOp::B1GridDesc_BL0_N_BL1,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
typename GridwiseOp::DefaultBlock2CTileMap,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b0_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b0_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.batch_count_,
arg.compute_ptr_offset_of_batch_,
arg.c0_matrix_mask_,
arg.block_2_ctile_map_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100")
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
return false;
}
if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
return false;
}
}
else
{
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
{
return false;
}
// Note: we need raw lengths since threadwise copy cannot handle a vector load when
// part of the vector is out of bounds
// Note: use the lowest dim in Ms/Ls/Ks/Ns, not the merged M/L/K/N
const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = NzRaw;
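// Illustrative numbers: with the example's ABlockTransferSrcVectorDim == 2 and
// ABlockTransferSrcScalarPerVector == 8, a raw K of 40 passes (40 % 8 == 0) while a raw
// K of 36 would be rejected, since the last vector load would straddle the tensor boundary.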
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
const B0DataType* p_b0,
const B1DataType* p_b1,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b0,
p_b1,
p_c,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ls_ks_lengths,
b0_gs_ls_ks_strides,
b1_gs_ns_ls_lengths,
b1_gs_ns_ls_strides,
c_gs_ms_ns_lengths,
c_gs_ms_ns_strides,
acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op};
}
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ls_ks_lengths,
b0_gs_ls_ks_strides,
b1_gs_ns_ls_lengths,
b1_gs_ns_ls_strides,
c_gs_ms_ns_lengths,
c_gs_ms_ns_strides,
acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< LPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< L0PerBlock << ", "
<< L1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec)
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
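To make the host-side contract of this operator concrete, the snippet below sketches how it is typically driven, written against the DeviceGemmInstance and element-op aliases from the example file earlier in this commit. The problem sizes, the packed row-major strides and the device pointers are placeholders; the example's run_batched_gemm_scale_softmax_gemm_permute.inc performs the same steps plus timing and CPU verification.

#include <array>
#include <vector>

float run_device_op_once(const ck::half_t* p_a,  // G0 x G1 x M x K
                         const ck::half_t* p_b0, // G0 x G1 x L x K
                         const ck::half_t* p_b1, // G0 x G1 x N x L
                         ck::half_t* p_c,        // G0 x G1 x M x N
                         ck::index_t G0, ck::index_t G1,
                         ck::index_t M, ck::index_t L,
                         ck::index_t K, ck::index_t N,
                         float alpha)
{
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    auto argument = gemm.MakeArgument(
        p_a, p_b0, p_b1, p_c,
        std::array<void*, 0>{}, std::array<void*, 0>{},            // no Acc0/Acc1 bias
        std::vector<ck::index_t>{G0, G1, M, K},                    // a_gs_ms_ks_lengths
        std::vector<ck::index_t>{G1 * M * K, M * K, K, 1},         // a strides (packed, placeholder)
        std::vector<ck::index_t>{G0, G1, L, K},                    // b0_gs_ls_ks_lengths
        std::vector<ck::index_t>{G1 * L * K, L * K, K, 1},
        std::vector<ck::index_t>{G0, G1, N, L},                    // b1_gs_ns_ls_lengths
        std::vector<ck::index_t>{G1 * N * L, N * L, L, 1},
        std::vector<ck::index_t>{G0, G1, M, N},                    // c_gs_ms_ns_lengths
        std::vector<ck::index_t>{G1 * M * N, M * N, N, 1},
        std::array<std::vector<ck::index_t>, 0>{},
        std::array<std::vector<ck::index_t>, 0>{},
        std::array<std::vector<ck::index_t>, 0>{},
        std::array<std::vector<ck::index_t>, 0>{},
        AElementOp{}, B0ElementOp{},
        Acc0ElementOp{alpha}, // Scale applied to the Gemm0 accumulator before softmax
        B1ElementOp{}, CElementOp{});

    if(!gemm.IsSupportedArgument(argument))
    {
        return -1.f; // unsupported device/tile/size combination
    }
    return invoker.Run(argument, StreamConfig{nullptr, true}); // elapsed time in ms
}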
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB0,
typename FloatB1,
typename FloatC,
typename AGridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_BL0_N_BL1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
const FloatA* __restrict__ p_a_grid,
const FloatB0* __restrict__ p_b0_grid,
const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const B0ElementwiseOperation b0_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_ak0_m_ak1,
b0_grid_desc_bk0_l_bk1,
b1_grid_desc_l0_n_l1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b0_grid_desc_bk0_l_bk1;
ignore = b1_grid_desc_l0_n_l1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b0_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = block_2_ctile_map;
#endif // !defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)
}
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
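// Between the two GEMMs each length-L accumulator row is normalized in place,
// acc[m, :] <- exp(acc[m, :] - max_l acc[m, l]) / sum_l exp(acc[m, l] - max_l acc[m, l]),
// using the BlockwiseSoftmax helper and the running max/sum buffers set up in Run()
// below; the acc_element_op (Scale in the wmma example) is applied to the Gemm0
// accumulator before this normalization.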
template <typename FloatA,
typename FloatB0,
typename FloatAcc0,
typename FloatB1,
typename FloatAcc1,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_BL0_N_BL1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t LPerBlock,
index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
index_t K1Value,
index_t NPerBlock,
index_t L0PerBlock,
index_t L1Value,
index_t MPerWmma,
index_t LPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t LRepeat,
index_t NRepeat,
index_t BlockSize,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_K1,
bool B0ThreadTransferSrcResetCoordinateAfterRun,
bool B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_L1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
bool B1BlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool PadN,
bool MaskOutUpperTriangle,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK0 = Number<K0PerBlock>{};
static constexpr auto AK1 = Number<K1Value>{};
static constexpr auto BK0 = Number<K0PerBlock>{};
static constexpr auto BK1 = Number<K1Value>{};
static constexpr auto AL0 = Number<L0PerBlock / 2>{};
static constexpr auto AL1 = Number<L1Value>{};
static constexpr auto BL0 = Number<L0PerBlock>{};
static constexpr auto BL1 = Number<L1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t A_K0 = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return transform_tensor_descriptor(
A0BlockDesc_AK0_M_AK1{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B0BlockDesc_BK0_L_BK1>
__host__ __device__ static constexpr auto
MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
{
constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
return transform_tensor_descriptor(
B0BlockDesc_BK0_L_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename A1BlockDesc_AL0_M_AL1>
__host__ __device__ static constexpr auto
MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
{
constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
return transform_tensor_descriptor(
A1BlockDesc_AL0_M_AL1{},
make_tuple(make_pass_through_transform(Number<A_L0>{}),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
make_pass_through_transform(Number<A_L1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B1BlockDesc_BL0_N_BL1>
__host__ __device__ static constexpr auto
MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
{
constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
B1BlockDesc_BL0_N_BL1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
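// The "+ ABlockLdsExtraM" (and the analogous B0/B1 extras below) pads each K0 slab of
// the LDS tile by one extra K1-wide row when the flag is set, a common trick to stagger
// rows across LDS banks and reduce bank conflicts (our reading of the flag's purpose).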
__host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<LPerBlock>{}, BK1),
make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BL0, Number<NPerBlock>{}, BL1),
make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
}
__host__ __device__ static constexpr auto
// *Caution: the "repeat" here is the shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
const index_t gemm0_bytes_end =
(SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatB1);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatAcc0);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(LPerBlock % (LPerWmma * LRepeat)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
const auto KPerBlock = K0PerBlock * K1Value;
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
{
return false;
}
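// e.g. with the wmma example's tiles (MPerBlock = 128, LPerBlock = 128, KPerBlock = 32,
// NPerBlock = 64) a problem of M = 256, L = 256, K = 64, N = 128 passes these checks;
// with a padding GemmSpecialization the lengths seen here are already the padded ones,
// so odd raw sizes remain usable.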
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(LPerBlock % (L0PerBlock * L1Value) == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = LPerBlock / (L0PerBlock * L1Value);
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1Value);
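// e.g. K = 128 with K0PerBlock * K1Value = 32 gives num_loop = 4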
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b0_block_desc_bk0_l_bk1 =
GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
static constexpr auto b1_block_desc_bl0_n_bl1 =
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b0_block_desc_bk0_l_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop,
typename C0MatrixMask,
typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB0* __restrict__ p_b0_grid,
const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_AK0_M_AK1& a_grid_desc_k0_m_k1,
const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const B0ElementwiseOperation& b0_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const C0MatrixMask& c0_matrix_mask,
const Block2CTileMap& block_2_ctile_map)
{
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ return; }
// Store BlockId into SGPR
const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
/*******************************************************************************/
// set up Gemm0
/*******************************************************************************/
/*******************************************************************************/
// Block level: A/B matrix thread mapping in LDS, as destination of blockwise copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<AK0, MPerBlock, AK1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0perblock_mperblock_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, LPerBlock, BK1>,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
FloatB0,
FloatB0,
decltype(b0_grid_desc_k0_l_k1),
decltype(b0_block_desc_k0perblock_lperblock_k1),
B0BlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
B0BlockTransferSrcVectorDim,
2,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_K1,
1,
1,
B0ThreadTransferSrcResetCoordinateAfterRun,
true>(
b0_grid_desc_k0_l_k1,
make_multi_index(0, 0, 0),
b0_element_op,
b0_block_desc_k0perblock_lperblock_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
/*******************************************************************************/
// Gemm0
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
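// With the wmma example's K1Value = 8 and WmmaK = 16, KPack resolves to 16 (K1 rounded
// up to the least multiple of the WMMA K width).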
auto blockwise_gemm0 = BlockwiseGemmWMMA<
BlockSize,
FloatA,
FloatB0,
FloatAcc0,
decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
MPerBlock,
LPerBlock,
K0PerBlock * K1Value,
MPerWmma,
LPerWmma,
MRepeat,
LRepeat,
KPack,
true>{}; // C' = B' x A'
// Prepare Register for A*B0 matrix
auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
make_pass_through_transform(laccvgprs)),
make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
/*******************************************************************************/
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
/*******************************************************************************/
// softmax
/*******************************************************************************/
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
SharedMemTrait::reduction_space_size_aligned);
// get acc0 7D thread cluster
constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() /
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0);
constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1);
constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2);
constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3);
constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
// get acc0 thread map
constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
make_pass_through_transform(I1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(
make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_l_n_thread_cluster_adaptor =
chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed(
make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs));
constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed(
make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatAcc0,
decltype(threadid_to_l_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_l),
decltype(thread_slice_desc_m_l)>{};
// Initialize the running sum and running max of the exponentiated row vectors
using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
running_sum = 0;
running_sum_new = 0;
running_max = NumericLimits<FloatAcc0>::Lowest();
running_max_new = NumericLimits<FloatAcc0>::Lowest();
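// Online (running) softmax state kept per output row, in VGPRs:
//   m_new = max(m_old, m_tile)
//   l_new = exp(m_old - m_new) * l_old + exp(m_tile - m_new) * l_tile
// where m_tile / l_tile are the per-row max and exp-sum produced by
// blockwise_softmax for the current L tile. These are the running_max /
// running_sum updates applied after each outer-loop iteration below.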
/*******************************************************************************/
// set up Gemm1
/*******************************************************************************/
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);
// A1 matrix in VGPR
constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
Number<AL0 * AL1 / laccvgprs>{},
Number<mrepeat * mwave * mthreadpersubgroup>{},
Number<laccvgprs>{}); // Data duplicated dimension
constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
// A1 has duplicated data
constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;
constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));
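// Note: the last dimension is doubled (I2 * A1ThreadSliceL1); presumably this holds
// the two row-lane copies consumed by the InterRow transfer below (an interpretation
// of the layout, not stated explicitly in the original code).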
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatAcc0,
FloatA,
decltype(acc0_thread_desc_l0perblock_mperblock_l1),
decltype(a1_thread_desc_l0perblock_mperblock_l1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
Sequence<0, 1, 2>,
2,
laccvgprs,
// dst RowLane: 0x76543210 0xfedcba98
// src RowLane:
0x76543210, 0xfedcba98,
false>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatB1,
/* typename DstData, */ FloatB1,
/* typename SrcDesc, */ decltype(b1_grid_desc_l0_n_l1),
/* typename DstDesc, */ decltype(b1_block_desc_l0perblock_nperblock_l1),
/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector,
/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1,
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
NumGemmKPrefetchStage>(
b1_grid_desc_l0_n_l1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_l0perblock_nperblock_l1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB1*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());
auto blockwise_gemm1 =
BlockwiseGemmWMMA<BlockSize,
FloatA,
FloatB1,
FloatAcc1,
decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
MPerBlock,
NPerBlock,
BL0 * BL1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack,
true>{make_tuple(0, 0, 0, 0, 0)};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);
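// Outer loop: one iteration per LPerBlock tile along GEMM_L (the dimension shared
// by softmax(A*B0) and B1). Inner loop: each LPerBlock tile is consumed by gemm1
// in chunks of BL0 * BL1.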
// Initialize C
StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
c_thread_buf.Clear();
/*******************************************************************************/
// Flash Attention
// Dao, Tri, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." arXiv preprint arXiv:2205.14135 (2022).
index_t gemm1_l_block_outer_index = 0;
// Outer loop, along GEMM_L
// Inner loop, along GEMM_K
do{
auto l_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock))
{
continue;
}
// gemm0 start, A and B swapped
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0perblock_mperblock_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b0_grid_desc_k0_l_k1,
b0_block_desc_k0perblock_lperblock_k1,
b0_blockwise_copy,
b0_grid_buf,
b0_block_buf,
b0_block_slice_copy_step,
blockwise_gemm0,
acc0_thread_buf,
K0BlockMainLoop);
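// acc0_thread_buf now holds this thread's slice of the gemm0 result
// (A * B0 for the current M/L tile), still in FloatAcc0.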
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
// 7d thread_desc in thread scope
constexpr auto c_thread_lengths =
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
// 7d block_desc in block scope
constexpr auto c_block_lengths =
blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
constexpr auto MREPEAT = c_block_lengths[I0];
constexpr auto MWAVE = c_block_lengths[I1];
constexpr auto MTHREADSubGroup = c_block_lengths[I2];
constexpr auto LREPEAT = c_block_lengths[I3];
constexpr auto LWAVE = c_block_lengths[I4];
constexpr auto LSUBGROUP = c_block_lengths[I5];
constexpr auto LACCVGPRS = c_block_lengths[I6];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D(
Number<0>{}, Number<0>{});
constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)),
make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto l_global = l_local + l_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, l_global))
{
acc0_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
}
else
{
acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]);
}
});
}
else
{
    static_for<0, acc0_thread_buf.Size(), 1>{}(
        [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}
block_sync_lds();
// gemm0 end
// Tiled softmax start
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
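// Subtracting running_max_new keeps every exponent <= 0, so the exp() calls
// cannot overflow; this is the standard numerically stable softmax update.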
// gemm1
{
// TODO: explore using a dynamic buffer for the a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). However, this is impossible to implement because
// the A1 source buffer is a static buffer holding the output of the first GEMM and
// requires a constexpr offset by design. Therefore, the tensor coordinate offset is
// passed explicitly to Run() below.
// Initialize acc1
acc1_thread_buf.Clear();
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
// main body
if constexpr(num_gemm1_l_block_inner_loop > 1)
{
static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
// Data cast from FloatAcc0 to FloatA happen here
a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_l0perblock_mperblock_l1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
});
}
// tail
{
a1_blockwise_copy.Run(
acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple(
Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_l0perblock_mperblock_l1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
}
} // end gemm1
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
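// The following loop rescales the previous partial output O with the updated
// softmax statistics and accumulates the new P*V tile:
//   O_new = (l_old * exp(m_old - m_new) * O + exp(m_tile - m_new) * (P*V)) / l_new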
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
FloatAcc1 c = c_thread_buf[I]; // O
FloatAcc1 c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM];
c_thread_buf(I) = c_new; // O_new
});
});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_reset_copy_step); // rewind K
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
b0_block_reset_copy_step); // rewind K and step N
// update before next j iteration
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
}while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);
/*******************************************************************************/
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
// This API provides all the dimension sizes needed below
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
// LDS descriptor: shuffle and write out over MRepeat x NRepeat iterations
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
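// The C-shuffle staging buffer starts at offset 0 of p_shared, reusing the LDS
// space that held the gemm0/gemm1 tiles, which are no longer needed at this point.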
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
MWave, // MWave
MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma
)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
NWave, // NWave
NSubGroup,
NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
FloatCShuffle,
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
I1,
I1,
CShuffleNRepeatPerShuffle,
I1,
I1,
NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
8, // DstScalarPerVector (vector write width)
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_multi_index(0,
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
0,
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
1,
1,
CShuffleNRepeatPerShuffle,
1,
1,
NAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
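// Each access shuffles one CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle
// tile: VGPR -> LDS (per thread), then LDS -> global memory (per block).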
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
// clang-format on
}
};
} // namespace ck