Commit 6985af40 authored by wangshaojie6

init code

parent 63914743
add_example_executable(example_gemm_gemm_xdl_fp16 gemm_gemm_xdl_fp16.cpp)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using QDataType = F16;
using KDataType = F16;
using PDataType = F16;
using VDataType = F16;
using RDataType = F16;
using GemmAccDataType = F32;
using QLayout = Row;
using KLayout = Col;
using PLayout = Row;
using VLayout = Row;
using RLayout = Row;
using QElementOp = ck::tensor_operation::element_wise::PassThrough;
using KElementOp = ck::tensor_operation::element_wise::PassThrough;
using PElementOp = ck::tensor_operation::element_wise::PassThrough;
using VElementOp = ck::tensor_operation::element_wise::PassThrough;
using RElementOp = ck::tensor_operation::element_wise::PassThrough;
//static constexpr auto GemmSpecialization =
// ck::tensor_operation::device::GemmSpecialization::Default;
using ReferenceGemmInstanceQKP = ck::tensor_operation::host::ReferenceBatchedGemm<QDataType,
KDataType,
PDataType,
QElementOp,
KElementOp,
PElementOp>;
using ReferenceGemmInstancePVR = ck::tensor_operation::host::ReferenceBatchedGemm<PDataType,
VDataType,
RDataType,
PElementOp,
VElementOp,
RElementOp>;
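// This example chains two batched GEMMs in the attention pattern (no softmax is
// applied yet): P = Q * K^T with shape [G, N_, N_], then R = P * V with shape
// [G, N_, d_]. The two ReferenceBatchedGemm instances above provide the host-side
// result used for verification.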
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t N_ = 1024;
ck::index_t d_ = 64;
#if 0
ck::index_t M_QKP = N_;
ck::index_t N_QKP = N_;
ck::index_t K_QKP = d_;
ck::index_t M_PVR = N_;
ck::index_t N_PVR = d_;
ck::index_t K_PVR = N_;
ck::index_t StrideQ = d_;
ck::index_t StrideK = d_;
ck::index_t StrideP = N_;
ck::index_t StrideV = d_;
ck::index_t StrideR = d_;
#endif
ck::index_t BatchCount = 8 * 12;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N_ = std::stoi(argv[4]);
d_ = std::stoi(argv[5]);
BatchCount = std::stoi(argv[6]);
#if 0
M_QKP = N_;
N_QKP = N_;
K_QKP = d_;
M_PVR = N_;
N_PVR = d_;
K_PVR = N_;
StrideQ = d_;
StrideK = d_;
StrideP = N_;
StrideV = d_;
StrideR = d_;
#endif
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 6: S (256x), d(128x), BatchCount(32x)\n");
exit(0);
}
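// The helper below builds a 3-D HostTensorDescriptor with lengths
// {batch_count, row, col}. For a row-major matrix the strides are
// {row * stride, stride, 1}; for a column-major matrix they are
// {col * stride, 1, stride}, so consecutive batches are packed back to back.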
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
}
};
Tensor<QDataType> q_g_n_d(f_host_tensor_descriptor(BatchCount, N_, d_, d_, QLayout{}));
Tensor<KDataType> k_g_d_n(f_host_tensor_descriptor(BatchCount, d_, N_, d_, KLayout{}));
Tensor<PDataType> p_g_n_n(f_host_tensor_descriptor(BatchCount, N_, N_, N_, PLayout{}));
Tensor<VDataType> v_g_n_d(f_host_tensor_descriptor(BatchCount, N_, d_, d_, VLayout{}));
Tensor<RDataType> r_g_n_d_host_result(f_host_tensor_descriptor(BatchCount, N_, d_, d_, RLayout{}));
Tensor<RDataType> r_g_n_d_device_result(f_host_tensor_descriptor(BatchCount, N_, d_, d_, RLayout{}));
std::cout << "q_g_n_d: " << q_g_n_d.mDesc << std::endl;
std::cout << "k_g_d_n: " << k_g_d_n.mDesc << std::endl;
std::cout << "p_g_n_n: " << p_g_n_n.mDesc << std::endl;
std::cout << "v_g_n_d: " << v_g_n_d.mDesc << std::endl;
std::cout << "r_g_n_d: " << r_g_n_d_host_result.mDesc << std::endl;
std::cout << "time kernel: " << time_kernel << std::endl;
switch (init_method)
{
case 0:
break;
case 1:
q_g_n_d.GenerateTensorValue(GeneratorTensor_2<QDataType>{-5, 5});
k_g_d_n.GenerateTensorValue(GeneratorTensor_2<KDataType>{-5, 5});
v_g_n_d.GenerateTensorValue(GeneratorTensor_2<VDataType>{-5, 5});
break;
default:
q_g_n_d.GenerateTensorValue(GeneratorTensor_3<QDataType>{0.0, 1.0});
k_g_d_n.GenerateTensorValue(GeneratorTensor_3<KDataType>{-0.5, 0.5});
v_g_n_d.GenerateTensorValue(GeneratorTensor_3<VDataType>{-0.5, 0.5});
break;
}
auto q_element_op = QElementOp{};
auto k_element_op = KElementOp{};
auto v_element_op = VElementOp{};
auto p_element_op = PElementOp{};
auto r_element_op = RElementOp{};
DeviceMem q_device_buf(sizeof(QDataType) * q_g_n_d.mDesc.GetElementSpace());
DeviceMem k_device_buf(sizeof(KDataType) * k_g_d_n.mDesc.GetElementSpace());
DeviceMem v_device_buf(sizeof(VDataType) * v_g_n_d.mDesc.GetElementSpace());
DeviceMem r_device_buf(sizeof(RDataType) *
r_g_n_d_device_result.mDesc.GetElementSpace());
q_device_buf.ToDevice(q_g_n_d.mData.data());
k_device_buf.ToDevice(k_g_d_n.mData.data());
v_device_buf.ToDevice(v_g_n_d.mData.data());
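// The device op (DeviceBatchedGemmGemmCShuffleXdl) is not invoked in this initial
// example yet; only the CPU references below are run. A minimal sketch of the usual
// CK invocation pattern (DeviceOpInstance is a placeholder alias, arguments elided):
//   auto op       = DeviceOpInstance{};
//   auto invoker  = op.MakeInvoker();
//   auto argument = op.MakeArgument(/* pointers, sizes, strides, element ops */);
//   if(op.IsSupportedArgument(argument))
//       invoker.Run(argument, StreamConfig{nullptr, time_kernel});
//   r_device_buf.FromDevice(r_g_n_d_device_result.mData.data());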
// bool pass = true;
if(do_verification)
{
auto ref_batched_gemmQKP = ReferenceGemmInstanceQKP{};
auto ref_invokerQKP = ref_batched_gemmQKP.MakeInvoker();
auto ref_argumentQKP = ref_batched_gemmQKP.MakeArgument(
q_g_n_d, k_g_d_n, p_g_n_n, q_element_op, k_element_op, p_element_op);
auto ref_batched_gemmPVR = ReferenceGemmInstancePVR{};
auto ref_invokerPVR = ref_batched_gemmPVR.MakeInvoker();
auto ref_argumentPVR = ref_batched_gemmPVR.MakeArgument(
p_g_n_n, v_g_n_d, r_g_n_d_host_result, p_element_op, v_element_op, r_element_op);
ref_invokerQKP.Run(ref_argumentQKP);
ref_invokerPVR.Run(ref_argumentPVR);
}
}
...@@ -45,3 +45,4 @@ add_subdirectory(23_softmax)
add_subdirectory(24_batched_gemm_c_permute)
add_subdirectory(25_gemm_bias_c_permute)
add_subdirectory(26_contraction)
add_subdirectory(27_gemm_gemm)
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct BatchedGemmGemmCShuffleDesc
{
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
};
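// Presumably G0/G1 are the two batch-like dimensions of the output (e.g. batch and
// head count) and M/N are the output matrix dimensions, each with its own stride,
// mirroring BatchedGemmCPermuteDesc from the batched_gemm_c_permute operator.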
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmGemmCShuffle : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmGemmCShuffleDesc batched_gemm_gemm_c_shuffle_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t BatchCount) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmGemmCShufflePtr = std::unique_ptr<
DeviceBatchedGemmGemmCShuffle<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp"
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename VGridDesc_K0_N_K1,
typename RGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename QElementwiseOperation,
typename KElementwiseOperation,
typename VElementwiseOperation,
typename PElementwiseOperation,
typename RElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_c_shuffle_xdl(const FloatAB* __restrict__ p_q_grid,
const FloatAB* __restrict__ p_k_grid,
const FloatAB* __restrict__ p_v_grid,
FloatC* __restrict__ p_o_grid,
const index_t batch_count,
const QGridDesc_K0_M_K1 q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1 k_grid_desc_k0_n_k1,
const VGridDesc_K0_N_K1 v_grid_desc_k0_n_k1,
const RGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
r_grid_desc_mblock_mperblock_nblock_nperblock,
const QElementwiseOperation q_element_op,
const KElementwiseOperation k_element_op,
const VElementwiseOperation v_element_op,
const PElementwiseOperation p_element_op,
const RElementwiseOperation r_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t q_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetQPtrOffset(g_idx)));
const long_index_t k_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetKPtrOffset(g_idx)));
const long_index_t v_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetVPtrOffset(g_idx)));
const long_index_t o_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetRPtrOffset(g_idx)));
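// Each workgroup processes one tile of one batch: g_idx selects the batch, and the
// per-batch pointer offsets are broadcast through readfirstlane so they are held in
// scalar registers rather than per-lane VGPRs.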
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_q_grid + q_batch_offset,
p_k_grid + k_batch_offset,
p_v_grid + v_batch_offset,
ck::Tuple<>{},
p_o_grid + o_batch_offset,
p_shared,
q_element_op,
k_element_op,
v_element_op,
p_element_op,
r_element_op,
q_grid_desc_k0_m_k1,
k_grid_desc_k0_n_k1,
v_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
r_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_q_grid;
ignore = p_v_grid;
ignore = p_k_grid;
ignore = p_o_grid;
ignore = batch_count;
ignore = q_grid_desc_k0_m_k1;
ignore = k_grid_desc_k0_n_k1;
ignore = v_grid_desc_k0_n_k1;
ignore = r_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = q_element_op;
ignore = k_element_op;
ignore = v_element_op;
ignore = p_element_op;
ignore = r_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
template <typename QLayout,
typename KLayout,
typename VLayout,
typename RLayout,
typename KDataType,
typename QDataType,
typename VDataType,
typename ODataType,
typename AccDataType,
typename KElementwiseOperation,
typename QElementwiseOperation,
typename VElementwiseOperation,
typename PElementwiseOperation,
typename OElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t QK1,
ck::index_t KK1,
ck::index_t VK1,
ck::index_t QKMPerXDL,
ck::index_t QKNPerXDL,
ck::index_t QKMXdlPerWave,
ck::index_t QKNXdlPerWave,
ck::index_t PVMPerXDL,
ck::index_t PVNPerXDL,
ck::index_t PVMXdlPerWave,
ck::index_t PVNXdlPerWave,
typename QBlockTransferThreadClusterLengths_K0_M_K1,
typename QBlockTransferThreadClusterArrangeOrder,
typename QBlockTransferSrcAccessOrder,
ck::index_t QBlockTransferSrcVectorDim,
ck::index_t QBlockTransferSrcScalarPerVector,
ck::index_t QBlockTransferDstScalarPerVector_K1,
ck::index_t QBlockLdsAddExtraM,
typename KBlockTransferThreadClusterLengths_K0_N_K1,
typename KBlockTransferThreadClusterArrangeOrder,
typename KBlockTransferSrcAccessOrder,
ck::index_t KBlockTransferSrcVectorDim,
ck::index_t KBlockTransferSrcScalarPerVector,
ck::index_t KBlockTransferDstScalarPerVector_K1,
ck::index_t KBlockLdsAddExtraN,
typename VBlockTransferThreadClusterLengths_K0_N_K1,
typename VBlockTransferThreadClusterArrangeOrder,
typename VBlockTransferSrcAccessOrder,
ck::index_t VBlockTransferSrcVectorDim,
ck::index_t VBlockTransferSrcScalarPerVector,
ck::index_t VBlockTransferDstScalarPerVector_K1,
ck::index_t VBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmGemmCShuffleXdl : public DeviceBatchedGemmGemmCShuffle<QElementwiseOperation,
KElementwiseOperation,
OElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmGemmCShuffleXdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeQGridDescriptor_QK0_M_QK1(index_t M, index_t K, index_t StrideQ)
{
// not pad M or K
assert(K % QK1 == 0);
const auto QK0 = K / QK1;
const auto q_grid_desc_qk0_m_qk1 = [&]() {
return make_naive_tensor_descriptor(make_tuple(QK0, M, QK1),
make_tuple(M * QK1, QK1, I1));
}();
return q_grid_desc_qk0_m_qk1;
}
static auto MakeKGridDescriptor_KK0_N_KK1(index_t N, index_t K, index_t StrideK)
{
// not pad M or K
assert(K % KK1 == 0);
const auto KK0 = K / KK1;
const auto k_grid_desc_kk0_n_kk1 = make_naive_tensor_descriptor(make_tuple(KK0, N, KK1),
make_tuple(KK1 * N, KK1, I1));
return k_grid_desc_kk0_n_kk1;
}
static auto MakeVGridDescriptor_VK0_N_VK1(index_t N, index_t K, index_t StrideV)
{
// not pad M or K
assert(K % VK1 == 0);
const auto VK0 = K / VK1;
const auto v_grid_desc_vk0_n_vk1 = make_naive_tensor_descriptor(make_tuple(VK0, N, VK1),
make_tuple(VK1 * N, VK1, I1));
return v_grid_desc_vk0_n_vk1;
}
static auto
MakeOGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
const auto o_grid_desc_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
}();
return o_grid_desc_mraw_nraw;
}
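// CK convention: the reduction dimension K is split as K = K0 * K1, giving
// [K0, M, K1] / [K0, N, K1] grid descriptors in which K1 is the contiguous
// (vector-load) dimension.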
using QGridDesc_K0_M_K1 = decltype(MakeQGridDescriptor_QK0_M_QK1(1, 1, 1));
using KGridDesc_K0_N_K1 = decltype(MakeKGridDescriptor_KK0_N_KK1(1, 1, 1));
using VGridDesc_K0_N_K1 = decltype(MakeVGridDescriptor_VK0_N_VK1(1, 1, 1));
using OGridDesc_M_N = decltype(MakeOGridDescriptor_M_N(1, 1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmGemmXdlopsSkipLdsV1<
BlockSize,
KDataType, // TODO: distinguish A/B datatype
AccDataType,
ODataType,
InMemoryDataOperationEnum::Set,
QGridDesc_K0_M_K1,
KGridDesc_K0_N_K1,
VGridDesc_K0_N_K1,
OGridDesc_M_N,
QElementwiseOperation,
KElementwiseOperation,
VElementwiseOperation,
PElementwiseOperation,
OElementwiseOperation,
QKMPerBlock,
QKNPerBlock,
QKMPerXDL,
QKNPerXDL,
PVMPerBlock,
PVNPerBlock,
PVMPerXDL,
PVNPerXDL,
KPerBlock,
QK1,
KK1,
VK1,
QKMXdlPerWave,
QKNXdlPerWave,
PVMXdlPerWave,
PVNXdlPerWave,
KBlockTransferThreadClusterLengths_K0_N_K1,
KBlockTransferThreadClusterArrangeOrder,
KBlockTransferSrcAccessOrder,
KBlockTransferSrcVectorDim,
KBlockTransferSrcScalarPerVector,
KBlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
KBlockLdsAddExtraN,
VBlockTransferThreadClusterLengths_K0_N_K1,
VBlockTransferThreadClusterArrangeOrder,
VBlockTransferSrcAccessOrder,
VBlockTransferSrcVectorDim,
VBlockTransferSrcScalarPerVector,
VBlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
VBlockLdsAddExtraN,
QBlockTransferSrcScalarPerVector,
false, // BThreadTransferSrcResetCoordinateAfterRun,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
// Argument
struct Argument : public BaseArgument
{
Argument(const QDataType* p_q_grid,
const KDataType* p_k_grid,
const VDataType* p_v_grid,
ODataType* p_o_grid,
index_t QKM,
index_t QKN,
index_t PVM,
index_t PVN,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
QElementwiseOperation q_element_op,
KElementwiseOperation k_element_op,
VElementwiseOperation v_element_op,
PElementwiseOperation p_element_op,
OElementwiseOperation o_element_op)
: p_q_grid_{p_q_grid},
p_k_grid_{p_k_grid},
p_v_grid_{p_v_grid},
p_o_grid_{p_o_grid},
q_grid_desc_k0_m_k1_{},
k_grid_desc_k0_n_k1_{},
o_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
q_element_op_{q_element_op},
k_element_op_{k_element_op},
p_element_op_{p_element_op},
v_element_op_{v_element_op},
o_element_op_{o_element_op}
{
q_grid_desc_k0_m_k1_ =
DeviceOp::MakeQGridDescriptor_QK0_M_QK1(QKM, K, StrideA);
k_grid_desc_k0_n_k1_ =
DeviceOp::MakeKGridDescriptor_KK0_N_KK1(QKN, K, StrideB);
v_grid_desc_k0_n_k1_ =
DeviceOp::MakeVGridDescriptor_VK0_N_VK1(PVN, K, StrideB);
o_grid_desc_m_n_ = DeviceOp::MakeOGridDescriptor_M_N(PVM, PVN, StrideC, 1); // assume row-major output: stride_N == 1
if(GridwiseGemm::CheckValidity(
q_grid_desc_k0_m_k1_, k_grid_desc_k0_n_k1_, o_grid_desc_m_n_, M01_, N01_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(o_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(o_grid_desc_m_n_, M01, N01);
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_ =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(k_grid_desc_k0_n_k1_);
}
}
// private:
const QDataType* p_q_grid_;
const KDataType* p_k_grid_;
const VDataType* p_v_grid_;
ODataType* p_o_grid_;
QGridDesc_K0_M_K1 q_grid_desc_k0_m_k1_;
KGridDesc_K0_N_K1 k_grid_desc_k0_n_k1_;
VGridDesc_K0_N_K1 v_grid_desc_k0_n_k1_;
OGridDesc_M_N o_grid_desc_m_n_;
typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
QElementwiseOperation q_element_op_;
KElementwiseOperation k_element_op_;
PElementwiseOperation p_element_op_;
VElementwiseOperation v_element_op_;
OElementwiseOperation o_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.q_grid_desc_k0_m_k1_{" << arg.q_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.q_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.q_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.k_grid_desc_k0_n_k1_{" << arg.k_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.k_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.k_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.o_grid_desc_m_n_{ " << arg.o_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.o_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.q_grid_desc_k0_m_k1_,
arg.k_grid_desc_k0_n_k1_,
arg.o_grid_desc_m_n_,
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.o_grid_desc_m_n_);
const auto K0 = arg.q_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm,
KDataType, // TODO: distinguish A/B datatype
ODataType,
remove_reference_t<DeviceOp::QGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::KGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
QElementwiseOperation,
KElementwiseOperation,
OElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_q_grid_,
arg.p_k_grid_,
arg.p_o_grid_,
arg.q_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.q_element_op_,
arg.k_element_op_,
arg.o_element_op_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
GridwiseGemm,
KDataType, // TODO: distinguish A/B datatype
ODataType,
remove_reference_t<DeviceOp::QGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::KGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
QElementwiseOperation,
KElementwiseOperation,
OElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_q_grid_,
arg.p_k_grid_,
arg.p_o_grid_,
arg.q_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.q_element_op_,
arg.k_element_op_,
arg.o_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.q_grid_desc_k0_m_k1_,
arg.k_grid_desc_k0_n_k1_,
arg.o_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const QDataType* p_a,
const KDataType* p_b,
ODataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
QElementwiseOperation q_element_op,
KElementwiseOperation k_element_op,
OElementwiseOperation o_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
q_element_op,
k_element_op,
o_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
QElementwiseOperation q_element_op,
KElementwiseOperation k_element_op,
OElementwiseOperation o_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const QDataType*>(p_a),
static_cast<const KDataType*>(p_b),
static_cast<ODataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
q_element_op,
k_element_op,
o_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceOp"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< ">";
// clang-format on
return str.str();
}
};
\ No newline at end of file
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename B0GridDesc_K0_N_K1,
typename B1GridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0ElementwiseOperation,
typename B1ElementwiseOperation,
typename C1ElementwiseOperation,
index_t M0PerBlock,
index_t N0PerBlock,
index_t M0PerXDL,
index_t N0PerXDL,
index_t M1PerBlock,
index_t N1PerBlock,
index_t M1PerXDL,
index_t N1PerXDL,
index_t KPerBlock,
index_t AK1,
index_t B0K1,
index_t B1K1,
index_t M0XdlPerWave,
index_t N0XdlPerWave,
index_t M1XdlPerWave,
index_t N1XdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename B1BlockTransferThreadClusterLengths_K0_M_K1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_K1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
bool B1BlockLdsExtraM,
index_t B0BlockTransferSrcScalarPerVector,
bool B0ThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmGemmXdlopsSkipLdsV1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto BaseMultK0 = 2;
static constexpr auto MultiK0 = BaseMultK0 * 1;
// K1 should be Number<...>
static constexpr auto K1 = Number<AK1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
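// A block is organized as MWaves x NWaves wavefronts of WaveSize lanes; each wave
// computes MXdlPerWave x NXdlPerWave xdlops tiles, and K0PerThread is the number of
// K0 slices a lane reads out of each K0PerBlock chunk.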
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
max_lds_align);
}
}();
return a_block_desc_k0_m_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned) * sizeof(FloatAB);
}
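// Only the A tile (MultiK0 K0-chunks deep) is staged through LDS; B is streamed
// directly from global memory into registers, which is the point of the skip-B-LDS
// pipeline, so no LDS is reserved for B.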
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % MultiK0 == 0))
{
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / (MultiK0 * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
}
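// The B grid is re-described as an 8-D tensor so that each thread can address its own
// fragment directly in global memory:
//   K0 -> (K0 / K0PerBlock, K0PerXdlops, K0PerThread)
//   N  -> (N block id, NXdlPerWave repeats, NWaves, NPerXDL)
//   K1 -> pass-through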
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetWaveKNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix threadwise copy
constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
I1,
Number<K0PerThread>{}, // K0PerThread
I1, // NBlockId
Number<NXdlPerWave>{}, // repeat
I1, // waves
I1, // NPerXdlops
Number<K1>{}));
using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
ignore = b_element_op;
// B matrix threadwise copy
constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
I1,
Number<K0PerThread>{}, // K0PerThread
I1, // NBlockId
Number<NXdlPerWave>{}, // repeat
I1, // waves
I1, // NPerXdlops
Number<K1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>
b_thread_1st_buf, b_thread_2nd_buf, b_thread_3rd_buf, b_thread_4th_buf;
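// Four per-thread register buffers for B are rotated so that global loads of upcoming
// B fragments overlap with the MFMA work on previously loaded ones.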
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
#if 0
const index_t block_id = get_block_1d_id();
const index_t thread_id = get_thread_local_1d_id();
printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} "
"kn id: {%d %d}\n",
block_id,
block_work_idx[I0],
block_work_idx[I1],
thread_id,
wave_id[I0],
wave_id[I1],
wave_id[I2],
wave_k_n_id[I0],
wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
#endif
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
Sequence<I1,
I1,
Number<K0PerThread>{},
I1,
Number<NXdlPerWave>{},
I1,
I1,
Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
// gridwise GEMM pipeline
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to regiester and LDS
{
// Read
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_1st_buf);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// Initialize C
c_thread_buf.Clear();
// a data write to lds
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// load 2nd a matrix data
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
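// Main pipeline: each iteration reads the next A tile from global memory, then the
// static_for below rotates the four B register buffers, issuing a B load, running
// blockwise_gemm on the buffer loaded two steps earlier, and advancing the A window
// in LDS; the next A tile is written to LDS after a barrier at the end of the iteration.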
// main body
if constexpr(HasMainK0BlockLoop)
{
index_t K0BlockMainLoop =
__builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
blockwise_gemm.ResetABlockStartWindow();
block_sync_lds();
static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_3rd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_4th_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
// 3rd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_1st_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
// 4th
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// move a and b window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step);
i += 1;
} while(i < (K0BlockMainLoop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_3rd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_4th_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_1st_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
}
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf);
}
}
};
} // namespace ck
\ No newline at end of file