Commit b924e330 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 72c9f129 9c0811f3
......@@ -15,10 +15,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
......@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<ALayout,
BLayout,
ELayout,
NDimSpatial,
MPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I1,
I0>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
static auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
......@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: p_a_grid_{},
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
num_group_{a_g_n_c_wis_lengths_[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
......@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
cde_element_op_{cde_element_op}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
// p_as and p_bs are pointers
p_a_grid_ = static_cast<const ADataType*>(p_as);
p_b_grid_ = static_cast<const BDataType*>(p_bs);
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
a_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
}
else
{
return 0;
}
}
void Print() const
......@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const BDataType* p_b_grid_;
EDataType* p_e_grid_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
// tensor descriptors for problem definiton
index_t num_group_;
......@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
// block-to-e-tile map
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
};
// Invoker
......@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
......@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const ADataType* p_a_grid = arg.p_a_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
}
typename GridwiseGemm::Argument gemm_arg{
arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1};
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
......@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return ave_time;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(p_a_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
......@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
// check device
if(get_device_name() == "gfx908")
{
......@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
......@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
......@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
......@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
};
} // namespace device
......
......@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCW_GKXC_NGKW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
}
// 2d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGC_GKYXC_NHWGK()
......@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
{
return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
}
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -16,95 +27,359 @@ template <typename InDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize,
ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
OutDataType,
IndexDataType,
ComputeDataType,
ReduceOpId,
OutputIndex,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorSize>
struct DevicePool2dFwd_NHWC_NHWC : public DevicePoolFwd<4,
2,
InDataType,
OutDataType,
IndexDataType,
tensor_layout::convolution::NHWC,
tensor_layout::convolution::NHWC,
ReduceOpId,
OutputIndex>
{
using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_nchw_lengths,
std::vector<ck::index_t> output_nchw_lengths,
std::vector<ck::index_t> input_nchw_stride,
std::vector<ck::index_t> output_nchw_stride,
std::vector<ck::index_t> window_spatial_yx_lengths,
std::vector<ck::index_t> window_yx_strides,
std::vector<ck::index_t> window_yx_dilations,
std::vector<ck::index_t> input_left_hw_pads,
std::vector<ck::index_t> input_right_hw_pads)
{
const index_t N = input_nchw_lengths[0];
const index_t C = input_nchw_lengths[1];
const index_t Hi = input_nchw_lengths[2];
const index_t Wi = input_nchw_lengths[3];
const index_t Ho = output_nchw_lengths[2];
const index_t Wo = output_nchw_lengths[3];
const index_t Y = window_spatial_yx_lengths[0];
const index_t X = window_spatial_yx_lengths[1];
const index_t WindowStrideH = window_yx_strides[0];
const index_t WindowStrideW = window_yx_strides[1];
const index_t WindowDilationH = window_yx_dilations[0];
const index_t WindowDilationW = window_yx_dilations[1];
const index_t InLeftPadH = input_left_hw_pads[0];
const index_t InLeftPadW = input_left_hw_pads[1];
const index_t InRightPadH = input_right_hw_pads[0];
const index_t InRightPadW = input_right_hw_pads[1];
const index_t MRaw = N * Ho * Wo * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = Y * X;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// A[ReduceM, ReduceK]
const index_t Ni_stride = input_nchw_stride[0];
const index_t Ci_stride = input_nchw_stride[1];
const index_t Hi_stride = input_nchw_stride[2];
const index_t Wi_stride = input_nchw_stride[3];
const auto in_grid_desc_n_hi_wi_c = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(Ni_stride, Hi_stride, Wi_stride, Ci_stride));
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_hip_wip_c,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_grid_desc_reducemraw_reducekraw =
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
make_merge_transform(make_tuple(Y, X))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
in_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const index_t No_stride = output_nchw_stride[0];
const index_t Co_stride = output_nchw_stride[1];
const index_t Ho_stride = output_nchw_stride[2];
const index_t Wo_stride = output_nchw_stride[3];
const auto out_grid_desc_n_ho_wo_c = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(No_stride, Ho_stride, Wo_stride, Co_stride));
const auto out_grid_desc_reducemraw =
transform_tensor_descriptor(out_grid_desc_n_ho_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto out_grid_desc_reducem =
transform_tensor_descriptor(out_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs =
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev,
std::vector<ck::index_t>& input_nchw_lengths,
std::vector<ck::index_t>& output_nchw_lengths,
std::vector<ck::index_t>& input_nchw_stride,
std::vector<ck::index_t>& output_nchw_stride,
std::vector<ck::index_t>&, // indices_nchw_stride
std::vector<ck::index_t>& window_spatial_yx_lengths,
std::vector<ck::index_t>& window_yx_strides,
std::vector<ck::index_t>& window_yx_dilations,
std::vector<ck::index_t>& input_left_hw_pads,
std::vector<ck::index_t>& input_right_hw_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{},
input_nchw_lengths_{input_nchw_lengths},
output_nchw_lengths_{output_nchw_lengths},
input_nchw_stride_{input_nchw_stride},
output_nchw_stride_{output_nchw_stride}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_nchw_lengths,
output_nchw_lengths,
input_nchw_stride,
output_nchw_stride,
window_spatial_yx_lengths,
window_yx_strides,
window_yx_dilations,
input_left_hw_pads,
input_right_hw_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
int32_t reduceLength = window_spatial_yx_lengths[0] * window_spatial_yx_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;
OutDataType* p_out_dev_;
IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
std::vector<ck::index_t> input_nchw_lengths_;
std::vector<ck::index_t> output_nchw_lengths_;
std::vector<ck::index_t> input_nchw_stride_;
std::vector<ck::index_t> output_nchw_stride_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
IndexDataType,
ComputeDataType,
ReduceOpId,
OutputIndex,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
std::unique_ptr<BaseArgument>
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OutputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_,
arg.in_element_op_,
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
// C should be fastest dimension
if(pArg->input_nchw_stride_[1] != 1)
return false;
for(int i = 0; i < InOutRank; ++i)
{
if(pArg->input_nchw_stride_[i] == 1 &&
pArg->input_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
if(pArg->output_nchw_stride_[i] == 1 &&
pArg->output_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t> input_stride,
std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> input_nchw_lengths,
std::vector<ck::index_t> window_yx_lengths,
std::vector<ck::index_t> output_nchw_lengths,
std::vector<ck::index_t> input_nchw_stride,
std::vector<ck::index_t> output_nchw_stride,
std::vector<ck::index_t> indices_nchw_stride,
std::vector<ck::index_t> window_yx_strides,
std::vector<ck::index_t> window_yx_dilations,
std::vector<ck::index_t> input_left_hw_pads,
std::vector<ck::index_t> input_right_hw_pads,
std::vector<ck::index_t> pooling_dims) override
{
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank)
if(input_nchw_lengths.size() != InOutRank || window_yx_lengths.size() != WindowRank ||
input_nchw_lengths.size() != InOutRank || window_yx_strides.size() != WindowRank ||
window_yx_dilations.size() != WindowRank || input_left_hw_pads.size() != WindowRank ||
input_right_hw_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3})
throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
// NCHW to NCDHW
input_lengths.insert(input_lengths.begin() + 2, 1);
output_lengths.insert(output_lengths.begin() + 2, 1);
input_stride.insert(input_stride.begin() + 2, 0);
output_stride.insert(output_stride.begin() + 2, 0);
indices_stride.insert(indices_stride.begin() + 2, 0);
// YX to ZYX
window_lengths.insert(window_lengths.begin(), 1);
window_strides.insert(window_strides.begin(), 0);
window_dilations.insert(window_dilations.begin(), 0);
input_left_pads.insert(input_left_pads.begin(), 0);
input_right_pads.insert(input_right_pads.begin(), 0);
pooling_dims = {2, 3, 4};
return DevicePool3D::MakeArgumentPointer(p_in_dev,
p_out_dev,
p_out_indices_dev,
input_lengths,
window_lengths,
output_lengths,
input_stride,
output_stride,
indices_stride,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
pooling_dims);
if(output_nchw_stride != indices_nchw_stride)
throw std::runtime_error(
"output_nchw_stride need to be equal to indices_nchw_stride for now");
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev),
input_nchw_lengths,
output_nchw_lengths,
input_nchw_stride,
output_nchw_stride,
indices_nchw_stride,
window_yx_lengths,
window_yx_strides,
window_yx_dilations,
input_left_hw_pads,
input_right_hw_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DevicePool2dFwd_NHWC_NHWC<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
......
......@@ -355,12 +355,39 @@ struct UnaryDivide
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
template <>
__host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<half_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<bhalf_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<f8_t>(x_ / divider_f_);
};
int32_t divider_ = 1;
};
......
......@@ -221,7 +221,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -303,7 +303,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
......@@ -576,12 +576,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
......
......@@ -255,7 +255,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -337,7 +337,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
......@@ -647,12 +647,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
......
......@@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
index_t tmp = 0;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
......
......@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!");
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
smfmac_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
template <typename ALayout,
typename BLayout,
typename ELayout,
index_t NDimSpatial,
index_t MPerThread,
index_t NPerThread>
struct TransformConvNGCHWToNHWGC
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& WiStride = g_n_c_wis_strides[I3];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& HiStride = g_n_c_wis_strides[I3];
const index_t& WiStride = g_n_c_wis_strides[I4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& DiStride = g_n_c_wis_strides[I3];
const index_t& HiStride = g_n_c_wis_strides[I4];
const index_t& WiStride = g_n_c_wis_strides[I5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
{
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
const auto G = g_n_c_wis_lengths[I0];
const auto C = g_n_c_wis_lengths[I2];
g_n_c_wis_strides_transposed[I0] = C;
g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
g_n_c_wis_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
g_n_c_wis_strides_transposed[I4] = G * C;
}
else if constexpr(NDimSpatial == 3)
{
g_n_c_wis_strides_transposed[I3] =
g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I5] = G * C;
}
return g_n_c_wis_strides_transposed;
}
else
{
// transpose not needed
return g_n_c_wis_strides;
}
}
};
} // namespace tensor_operation
} // namespace ck
......@@ -9,16 +9,18 @@ namespace ck {
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......
......@@ -52,12 +52,28 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value || is_same<T, half_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<f8_t>(a_ + b_);
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<half_t>(a_ + b_);
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
......@@ -112,12 +128,28 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value || is_same<T, half_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<f8_t>(a_ * b_);
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<half_t>(a_ * b_);
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
......@@ -137,6 +169,16 @@ struct Max
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
if constexpr(is_same_v<T, f8_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<f8_t>(val);
}
if constexpr(is_same_v<T, half_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<half_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
......@@ -154,8 +196,7 @@ struct Max
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
......@@ -171,12 +212,29 @@ struct Max
a = b;
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
......@@ -197,6 +255,30 @@ struct Max
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
struct Min
......@@ -209,6 +291,16 @@ struct Min
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else if constexpr(is_same_v<T, half_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<half_t>(val);
}
else if constexpr(is_same_v<T, f8_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<f8_t>(val);
}
else
{
return NumericLimits<T>::Max();
......@@ -227,8 +319,7 @@ struct Min
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
......@@ -244,6 +335,24 @@ struct Min
a = b;
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -270,6 +379,30 @@ struct Min
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
};
struct AMax
......@@ -299,6 +432,15 @@ struct AMax
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -313,6 +455,18 @@ struct AMax
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
template <typename T>
......@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value ||
is_same<DataType, f8_t>::value;
};
template <typename DataType>
......
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
namespace ck_tile {
......@@ -13,6 +14,9 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
......@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[0]
: a_m_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
......@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_element_op(a_m_k(m, k));
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
......@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
}
};
make_ParallelTensorFunctor(f,
c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
acc += static_cast<AccDataType>(A[row * strideA + k]) *
static_cast<AccDataType>(B[col * strideB + k]);
}
C[row * strideC + col] = acc; // Store as AccDataType
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return;
}
} // namespace ck_tile
......@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
......@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
......@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
......@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
......
......@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
......
......@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
......@@ -21,6 +21,8 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
......@@ -4,7 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
namespace ck_tile {
......@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1<
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
BlockGemmARegBGmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
......@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
}
// C = A * B
......@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window);
return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
}
};
......
......@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <string>
namespace ck_tile {
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
struct GemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmCommonKargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
else
{ // Default NK layout
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
}();
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});
auto ABlockWindow = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});
auto BBlockWindow = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
}();
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, acc);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment