Commit 497c30e0 authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into migx-jit-lib
parents 17acbbf4 8b9cbba8
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
GridwiseGemm::Run(gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
gemm_desc_ptr[group_id].ds_ptr_,
gemm_desc_ptr[group_id].e_ptr_,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_desc_ptr[group_id].a_grid_desc_k0_m0_m1_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n0_n1_k1_,
gemm_desc_ptr[group_id].ds_grid_desc_m0_m10_m11_n0_n10_n11_,
gemm_desc_ptr[group_id].e_grid_desc_m0_m10_m11_n0_n10_n11_,
gemm_desc_ptr[group_id].block_2_etile_map_,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleD_Dl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
ADataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
EGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using DsGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
using EGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
struct GroupedGemmBlock2ETileMap
{
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}))>;
GroupedGemmBlock2ETileMap()
{
block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{});
BlockStart_ = -1;
}
GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
{
block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n);
BlockStart_ = BlockStart;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_2_etile_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] - BlockStart_));
}
// it's actually E-Tile
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
__host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
{
return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
}
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_;
};
struct GemmKernelArg
{
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
typename GridwiseGemm::DsGridPointer ds_ptr_;
EDataType* e_ptr_;
// tensor descriptors for problem definiton
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
// block-to-e-tile map
GroupedGemmBlock2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
grid_size_ = 0;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
a_mtx_mraw_kraw_.emplace_back(M, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
if(M == 0)
{
skipped_group_count_++;
continue;
}
const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideE = gemm_descs[i].stride_C_;
typename GridwiseGemm::DsGridPointer p_ds_grid{};
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
});
// tensor descriptors for problem definiton
const auto a_grid_desc_k0_m_k1 =
DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
const auto b_grid_desc_k0_n_k1 =
DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n))
{
const index_t grid_size_grp =
GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
.block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
const auto b_grid_desc_k0_n0_n1_k1 =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
const auto ds_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n);
const auto e_grid_desc_m0_m10_m11_n0_n10_n11 =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmKernelArg{static_cast<const ADataType*>(p_As[i]),
static_cast<const BDataType*>(p_Bs[i]),
p_ds_grid,
static_cast<EDataType*>(p_Es[i]),
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_m_n,
e_grid_desc_m_n,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11,
e_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_etile_map,
BlockStart,
BlockEnd});
}
}
}
// private:
index_t group_count_;
index_t skipped_group_count_;
// TODO: A,B element op is unused since gridwise_gemm_dl_v1r3 does NOT support prologue
// for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
std::vector<GemmKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
index_t grid_size_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0);
bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
bool all_has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
}
K0 = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
bool not_all_has_main_k_block_loop_same =
all_has_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K0);
bool not_all_has_double_tail_k_block_loop_same =
all_has_double_tail_k_block_loop xor
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(not_all_has_main_k_block_loop_same or not_all_has_double_tail_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for [main|double_tail]_k_block_loop! in "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice));
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
const auto kernel = kernel_grouped_gemm_multiple_d_dl<GridwiseGemm,
GemmKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
};
if(all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(all_has_main_k_block_loop && !all_has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
return false;
}
const std::string device_name = ck::get_device_name();
// TODO add newer Navi arch
if(device_name != "gfx906" and device_name != "gfx908" and device_name != "gfx90a" and
device_name != "gfx1030")
{
return false;
}
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
{
return false;
}
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemmMultipleD_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmKernelArg);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -381,6 +381,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
a_mtx_mraw_kraw_.emplace_back(M, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
if(M == 0)
{
skipped_group_count_++;
......@@ -485,6 +488,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
CDEElementwiseOperation c_element_op_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
index_t grid_size_;
};
......@@ -599,7 +604,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return false;
}
return true;
bool supported = true;
// If we use padding we do not support vector loads for dimensions not divisible by vector
// load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
// thus we have to adapt it to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
}
}
return supported;
}
// polymorphic
......@@ -661,7 +687,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
......@@ -280,43 +281,42 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float x = c + d;
FastGelu{}.template operator()<float, float>(e, x);
}
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c + d;
e = type_convert<E>(y);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
static_assert(is_valid_param_type_v<D>);
const float x0_f = c + d;
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = GetFastGeLU(c + type_convert<float>(d));
e = type_convert<half_t>(x1_f);
}
};
......
......@@ -16,7 +16,7 @@ namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// siliently do implicit type conversion
//
// Method 1:
// Example:
//
// struct ExampleElementwiseOp
// {
......@@ -30,19 +30,6 @@ namespace element_wise {
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct AddReluAdd
{
......@@ -208,41 +195,74 @@ struct AddMultiply
}
};
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
const float x = c + d0 + d1;
FastGelu{}.template operator()<float, float>(e, x);
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const half_t x = c + d0 + d1;
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& e, const float& c, const half_t& d0, const half_t& d1) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
const float x0_f = c + d0 + d1;
const float y =
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
float x1_f = 0;
e = type_convert<E>(y);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
{
const float x0_f =
type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<int8_t>(x1_f);
}
};
......
#pragma once
#include "ck/utility/data_type.hpp"
// #include "ck/utility/get_id.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenient
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Activation_Mul_Clamp
{
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
......@@ -17,26 +38,66 @@ struct Activation_Mul_Clamp
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
__device__ constexpr void operator()(int32_t& y, const int32_t& x) const
{
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
__host__ constexpr void operator()(float& y, const float& x) const
{
// CAUSION - We might float in & float out in reference code
activationOp_(y, x);
y = math::clamp(requantScale_ * y, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float y_fp32 = ck::type_convert<float>(x);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Activation_Mul2_Clamp
{
......@@ -51,13 +112,35 @@ struct Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__device__ constexpr void
operator()(int32_t& y, const int32_t& x, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activaiton(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
......@@ -72,6 +155,17 @@ struct Add_Activation_Mul_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
......@@ -92,15 +186,33 @@ struct Add_Activation_Mul2_Clamp
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
: scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
{
}
......@@ -108,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32;
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
float scaleAcc_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy *Qy) != Sy * Activation(Qy)
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale1_;
float requantScale2_;
__host__ __device__ constexpr void
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = scaleAcc * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int32_t>(y_fp32);
}
float scale_z_inv_;
Activation activationOp_;
};
......
......@@ -11,6 +11,10 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThrough
{
template <typename Y, typename X>
......@@ -200,36 +204,83 @@ struct Relu
}
};
// Y = FastGelu(X)
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function
struct FastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
y = x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
// device code, use lower precision "__expf" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u);
#if !CK_WORKAROUND_SWDEV_383542
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
#else
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
y = x * cdf;
}
template <>
__host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
}
template <>
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
float y_f;
const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<half_t>(y_f);
}
template <>
__host__ void operator()<half_t, float>(half_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<half_t>(y_f);
}
template <>
__device__ void operator()<half_t, float>(half_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<half_t>(y_f);
}
};
......@@ -265,8 +316,36 @@ struct Sigmoid
y = 1 / (ck::type_convert<T>(1) + exp(-x));
};
};
int32_t divider_ = 1;
struct TanH
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
};
float beta_ = 1.0f;
};
} // namespace element_wise
......
......@@ -18,6 +18,10 @@
namespace ck {
/**
* @brief Gridwise gemm + softmax + gemm fusion
*
*/
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......
......@@ -185,8 +185,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
return b_grid_desc_k0_n0_n1_k1;
}
// E desc for destination in blockwise copy
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -238,19 +240,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
void* __restrict__ p_shared_block,
const AElementwiseOperation&,
const BElementwiseOperation&,
const CDEElementwiseOperation& cde_element_op,
......@@ -399,8 +401,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block_double =
static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -574,4 +574,546 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
}
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_bkm_bkn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_b_k0_m_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_b_k0_n_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&
K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&
K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&
KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
{
const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor(
a_grid_desc_b_k0_m_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_grid_desc_b_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
{
const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);
const auto K0 = b_grid_desc_b_k0_n_k1.GetLength(I1);
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor(
b_grid_desc_b_k0_n_k1,
make_tuple(make_pass_through_transform(KBatch),
make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_b_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_m_n_grid_desc, M01, N01, KBatch);
}
using AGridDesc_B_K0_M0_M1_K1 =
decltype(MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
using BGridDesc_B_K0_N0_N1_K1 =
decltype(MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1,
const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const CBlockClusterAdaptor& c_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_b_k0_m0_m1_k1)>,
decltype(a_block_desc_b_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_b_k0_m0_m1_k1,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0, 0),
a_block_desc_b_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_b_k0_n0_n1_k1)>,
decltype(b_block_desc_b_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_b_k0_n0_n1_k1,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0, 0),
b_block_desc_b_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(m_block_data_idx_on_grid,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
n_block_data_idx_on_grid,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
......@@ -54,7 +54,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -147,7 +148,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
......@@ -242,7 +244,8 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -271,7 +274,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx1100__ ))
}
template < // DataType Family
......@@ -428,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
......@@ -436,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
constexpr auto c_block_space_size_aligned = math::integer_least_multiple(
cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(),
max_lds_align);
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType)),
c_block_space_size_aligned * sizeof(CShuffleDataType));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -494,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
......@@ -673,7 +693,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
ADataType,
BDataType,
AccDataType,
......@@ -716,7 +736,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
c_thread_buf,
K0BlockMainLoop);
/*******************************************************************************/
//printf("safe 1");
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
using ABDataTypeAdjusted = ABDataType;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
......@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
ABDataTypeAdjusted,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
ABDataTypeAdjusted,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataTypeAdjusted,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
......@@ -49,7 +49,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -263,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
{
return false;
}
return true;
}
......@@ -414,7 +422,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatA,
FloatB,
FloatAcc,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -166,15 +166,12 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
// M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
......@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// sanity check
constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatABAdjusted,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
......@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else
using FloatABAdjusted = FloatAB;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
......@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
......@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatABAdjusted,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
......@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatABAdjusted,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
......@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......
......@@ -18,60 +18,23 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ FloatAB p_shared_block[shared_block_size];
__shared__ uint8_t p_shared[shared_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
static_cast<void*>(p_shared_block),
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared));
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -79,13 +42,13 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CMNGridDesc,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
......@@ -127,9 +90,237 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
struct Argument : public ck::tensor_operation::device::BaseArgument
{
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t K0;
index_t k_batch;
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t k_batch_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
p_c_grid(p_c_grid_),
M(M_),
N(N_),
K(K_),
StrideA(StrideA_),
StrideB(StrideB_),
StrideC(StrideC_),
MPadded(MPadded_),
NPadded(NPadded_),
KPadded(KPadded_),
K0(K0_),
k_batch(k_batch_)
{
}
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "K0:" << K0 << ", "
<< "KB:" << k_batch << "}" << std::endl;
}
};
__host__ __device__ static auto CalculateGridSize(const Argument& karg)
{
return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
math::integer_divide_ceil(karg.M, MPerBlock),
karg.k_batch);
}
// prefer this to be called on host
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
index_t MPad,
index_t K,
index_t StrideA,
index_t KBatch,
index_t K0,
index_t KPad)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
index_t NPad,
index_t N,
index_t StrideB,
index_t KBatch,
index_t K0,
index_t KPad)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
......@@ -178,45 +369,68 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
c_block_size * sizeof(FloatC));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
const Block2CTileMap& block_2_ctile_map)
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
return false;
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
return false;
}
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
return false;
}
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
else
{
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
......@@ -224,8 +438,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return has_main_k0_block_loop;
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
......@@ -242,10 +457,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
// return block_id to C matrix tile idx (m0, n0) mapping
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
c_m_n_grid_desc, 8, KBatch);
}
......@@ -262,24 +478,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const AElementwiseOperation a_element_op = AElementwiseOperation{};
const BElementwiseOperation b_element_op = BElementwiseOperation{};
const CElementwiseOperation c_element_op = CElementwiseOperation{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -289,26 +506,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
......@@ -444,7 +651,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
......@@ -647,7 +853,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
......@@ -716,6 +922,48 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
});
}
}
template <typename Layout>
struct LStr
{
static std::string Get() { return ""; }
};
template <>
struct LStr<ck::tensor_layout::gemm::RowMajor>
{
static std::string Get() { return "R"; }
};
template <>
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
{
static std::string Get() { return "C"; }
};
static std::string GetTypeString()
{
auto str = std::stringstream();
// clang-format off
str << "GemmXdlSplitKCShuffle_"
<< getGemmSpecializationString(GemmSpec) << "_"
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< "_"
<< "B" << BlockSize << "_"
<< "Vec" << ABlockTransferSrcScalarPerVector << "x"
<< BBlockTransferSrcScalarPerVector << "x"
<< CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
<< MPerBlock << "x"
<< NPerBlock << "x"
<< K0PerBlock << "x"
<< K1 ;
// clang-format on
return str.str();
}
};
} // namespace ck
......@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4
SrcCoord src_ref_coord_;
};
// Do NOT involve any tensor coordinates with StaticBuffer
/**
* @brief Threadwise data transfer
*
* Do NOT involve any tensor coordinates with StaticBuffer
*
*/
template <typename SrcData,
typename DstData,
typename SrcDesc,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
......@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
......@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......
......@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
}
} // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment