Unverified Commit 9f8ab221 authored by zjing14, committed by GitHub

Merge branch 'develop' into add_int8_wmma_example_instance

parents 755ace59 b4fc4d0b
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
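// Usage sketch (hypothetical pointer/index names, for illustration only): a batched
// kernel advances each base pointer by the per-batch offset returned here, e.g.
//   const auto a_ptr = p_a_grid + compute_ptr_offset.GetAPtrOffset(g_idx); // g_idx * BatchStrideA_
//   const auto e_ptr = p_e_grid + compute_ptr_offset.GetEPtrOffset(g_idx); // g_idx * BatchStrideE_
// GetCPtrOffset()/BatchStrideC_ simply alias the E members for kernels without extra D tensors.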
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -5,64 +5,41 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename GridwiseImageToColumnKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_image_to_column(const InputGridDesc in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseImageToColumnKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
#else
ignore = in_grid_desc;
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = block_2_tile_map;
#endif
}
// Image to column for input layout NDHWC:
// input : input image [N, Di, Hi, Wi, C],
// output : output image [N * Do * Ho * Wo, Z * Y * X * C]
// input : input image [N, Di, Hi, Wi, C]
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
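// Hypothetical example of this mapping: N = 2, C = 8, Di = Hi = Wi = 6, Z = Y = X = 3,
// unit strides/dilations and no padding give Do = Ho = Wo = 4, so the gemm form has shape
// [2 * 4 * 4 * 4, 3 * 3 * 3 * 8] = [128, 216].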
template <index_t NDimSpatial,
typename InputLayout,
typename ImageLayout,
typename InputDataType,
typename OutputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector>
index_t ScalarPerVector,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct DeviceImageToColumnImpl
: public DeviceImageToColumn<NDimSpatial, InputLayout, InputDataType, OutputDataType>
: public DeviceConvTensorRearrange<NDimSpatial,
ImageLayout,
InputDataType,
OutputDataType,
conv_tensor_rearrange_op::ImageToColumn>
{
static constexpr auto I0 = Number<0>{};
......@@ -83,7 +60,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -110,9 +87,9 @@ struct DeviceImageToColumnImpl
c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<InputLayout>(
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
input_g_n_c_wis_strides,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
......@@ -132,7 +109,7 @@ struct DeviceImageToColumnImpl
const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& output_m_k_strides)
const std::array<index_t, 2>& gemm_m_k_strides)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
......@@ -141,7 +118,7 @@ struct DeviceImageToColumnImpl
C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1]));
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
return desc_m_k;
......@@ -155,28 +132,29 @@ struct DeviceImageToColumnImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>;
using GridwiseImageToColumnKernel = GridwiseImageToColumn<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
Block2ETileMap>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap>;
struct Argument : public BaseArgument
{
Argument(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -185,7 +163,7 @@ struct DeviceImageToColumnImpl
X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)},
input_g_n_c_wis_strides_{input_g_n_c_wis_strides},
image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
......@@ -197,7 +175,7 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
image_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
......@@ -205,7 +183,7 @@ struct DeviceImageToColumnImpl
input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides);
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
}
void Print() const
......@@ -220,7 +198,7 @@ struct DeviceImageToColumnImpl
const InputDataType* p_in_;
OutputDataType* p_out_;
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<index_t, NDimSpatial>& input_left_pads_;
......@@ -243,12 +221,12 @@ struct DeviceImageToColumnImpl
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
arg.out_grid_desc_m_k_);
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
const auto kernel = kernel_image_to_column<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
Block2ETileMap,
GridwiseImageToColumnKernel>;
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
Block2ETileMap,
GridwiseTensorRearrangeKernel>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -273,12 +251,8 @@ struct DeviceImageToColumnImpl
bool IsSupportedArgument(const Argument& arg)
{
using namespace tensor_layout::convolution;
if(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>))
{
return false;
}
if(!(NDimSpatial >= 1 && NDimSpatial <= 3))
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
......@@ -287,8 +261,8 @@ struct DeviceImageToColumnImpl
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1;
bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
// check vector access when C is not packed
if(!is_c_packed && ScalarPerVector != 1)
......@@ -310,8 +284,8 @@ struct DeviceImageToColumnImpl
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
return false;
return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_,
arg.out_grid_desc_m_k_);
return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_,
arg.out_grid_desc_m_k_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -320,14 +294,14 @@ struct DeviceImageToColumnImpl
}
static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -340,8 +314,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
image_g_n_c_wis_strides,
gemm_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -352,14 +326,14 @@ struct DeviceImageToColumnImpl
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -372,8 +346,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
image_g_n_c_wis_strides,
gemm_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
......@@ -28,6 +28,7 @@ template <typename XDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim,
......@@ -43,12 +44,13 @@ template <typename XDataType,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
......@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int numBlockTileIteration)
{
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
......@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded);
};
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
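// For illustration (hypothetical sizes): a merged invariant length of 1000 with
// M_BlockTileSize = 128 gives integer_least_multiple(1000, 128) = 1024, i.e. pad_M = 24,
// so the padded M descriptor covers 1024 rows.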
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
struct Argument : public BaseArgument
{
......@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
y_elementwise_op_(y_elementwise_op)
{
epsilon_ = static_cast<ComputeDataType>(epsilon);
......@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
long_index_t invariant_length;
long_index_t reduce_length;
std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
......@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
}
ComputeDataType epsilon_;
......@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_;
......@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_;
GridDesc_M save_mean_grid_desc_m_;
GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_;
index_t MRaw_; // invariant length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
};
struct Invoker : public BaseInvoker
......@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
GridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
SaveMeanInvStdDstVectorSize,
UseWelford>(arg.isSweeponce_);
float avg_time = 0;
......@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.save_mean_grid_desc_m_,
arg.save_inv_std_grid_desc_m_,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_);
return (avg_time);
......@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
......@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
}
else
{
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false;
};
}
......@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
......@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (false);
}
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true;
};
......@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims,
double epsilon,
const void* p_x,
......@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_beta,
void* p_y,
void* p_saveMean,
void* p_saveInvVar,
void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speed up the backward pass
ignore = p_saveMean;
ignore = p_saveInvVar;
if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
betaStrides.size() != Rank || yStrides.size() != Rank ||
saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths,
xStrides,
gammaStrides,
betaStrides,
yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims,
y_elementwise_op,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y));
static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
......@@ -19,7 +19,7 @@
namespace ck {
template <typename GridwiseWelford,
typename XDataType,
typename MeanVarDataType,
typename WorkspaceMeanVarDataType,
typename ComputeDataType,
typename XGridDesc_M_K,
typename MeanVarGridDesc_M_KBlock>
......@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x_global,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count)
{
GridwiseWelford::Run(x_grid_desc_m_k,
......@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
};
template <typename GridwiseWelfordNormalization,
typename MeanVarDataType,
typename WorkspaceMeanVarDataType,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K>
typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M>
__global__ void
kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
......@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration,
index_t k_grid_size,
ComputeDataType epsilon,
const MeanVarDataType* const p_mean_global,
const MeanVarDataType* const p_variance_global,
const WorkspaceMeanVarDataType* const p_mean_global,
const WorkspaceMeanVarDataType* const p_variance_global,
const int32_t* const p_welford_count_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
......@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_mean_var_count_iteration,
num_k_block_tile_iteration,
k_grid_size,
......@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
p_gamma_global,
p_beta_global,
p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op);
};
} // namespace ck
......@@ -107,6 +117,7 @@ template <typename XDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim,
......@@ -121,17 +132,18 @@ template <typename XDataType,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize>
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
{
using MeanVarDataType = ComputeDataType;
using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(
......@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int kBlockSize,
int numBlockTileIteration)
{
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
......@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
};
template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K)
static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
{
const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
......@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}
template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeCountDescriptor_M_K(index_t M, index_t K)
static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
{
const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
}
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using Kernel1MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
using Kernel2MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using Kernel2CountGridDesc_M_KBlock =
decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
decltype(MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType,
ComputeDataType,
MeanVarDataType,
WorkspaceMeanVarDataType,
SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock,
BlockSize,
......@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
XSrcVectorSize>;
using GridwiseWelfordNormalization =
GridwiseNormalizationSplitK2nd<MeanVarDataType,
GridwiseNormalizationSplitK2nd<WorkspaceMeanVarDataType,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYVectorDim,
YDstVectorSize>;
YDstVectorSize,
SaveMeanInvStdDstVectorSize>;
struct Argument : public BaseArgument
{
......@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
......@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
......@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(MRaw_,
kGridSize_);
MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(
MRaw_, kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
}
ComputeDataType epsilon_;
......@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
......@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_;
......@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K gamma_grid_desc_m_k_;
SrcGridDesc_M_K beta_grid_desc_m_k_;
SrcGridDesc_M_K y_grid_desc_m_k_;
SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_;
SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_;
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
......@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t MRaw_; // invariant length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
};
struct Invoker : public BaseInvoker
......@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford,
XDataType,
MeanVarDataType,
WorkspaceMeanVarDataType,
ComputeDataType,
SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock>;
auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization,
MeanVarDataType,
WorkspaceMeanVarDataType,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K>;
SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel1,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.kernel1_mean_var_grid_desc_m_kblock_,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_),
static_cast<MeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_));
avg_time += launch_and_time_kernel(stream_config,
kernel2,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.kernel2_mean_var_grid_desc_m_kblock_,
arg.kernel2_count_grid_desc_m_kblock_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.numMeanVarCountIteration_,
arg.numBlockTileIteration_,
arg.kGridSize_,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_),
static_cast<MeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.y_elementwise_op_);
avg_time += launch_and_time_kernel(
stream_config,
kernel1,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.kernel1_mean_var_grid_desc_m_kblock_,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_));
avg_time += launch_and_time_kernel(
stream_config,
kernel2,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.kernel2_mean_var_grid_desc_m_kblock_,
arg.kernel2_count_grid_desc_m_kblock_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.save_mean_grid_desc_m_,
arg.save_inv_std_grid_desc_m_,
arg.numMeanVarCountIteration_,
arg.numBlockTileIteration_,
arg.kGridSize_,
arg.epsilon_,
static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_);
return avg_time;
};
......@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
// workspace for welford intermediate mean
workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
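// Rough example (hypothetical sizes, assuming a 4-byte WorkspaceMeanVarDataType):
// MRaw_ = 1024 and kGridSize_ = 8 give welford_size = 8192, so the total workspace is
// 2 * (8192 * 4 + 64) + (8 * 4 + 64) = 65760 bytes (~64 KiB) for mean, variance and count.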
......@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = welford_size * sizeof(MeanVarDataType);
index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = welford_size * sizeof(MeanVarDataType);
index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
......@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
......@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false;
};
}
......@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return false;
}
else // if fastest dim is reduced
......@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->kGridSize_ <= 1)
return false;
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true;
};
......@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims,
double epsilon,
const void* p_x,
......@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const void* p_beta,
void* p_y,
void* p_saveMean,
void* p_saveInvVar,
void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speed up the backward pass
ignore = p_saveMean;
ignore = p_saveInvVar;
if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
betaStrides.size() != Rank || yStrides.size() != Rank ||
saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths,
xStrides,
gammaStrides,
betaStrides,
yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims,
y_elementwise_op,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y));
static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
......@@ -113,7 +113,6 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
......@@ -143,7 +142,36 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
template <>
__host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, bf8_t>(float& y, const bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, float>(bf8_t& y, const float& x) const
{
y = type_convert<bf8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, bf8_t>(half_t& y, const bf8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{
y = ck::type_convert<bf8_t>(x);
}
};
struct UnaryConvert
......@@ -172,7 +200,6 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
......@@ -189,7 +216,6 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{
......@@ -416,10 +442,11 @@ struct Sigmoid
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = 1 / (ck::type_convert<T>(1) + exp(-x));
constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x));
};
};
......@@ -429,7 +456,8 @@ struct TanH
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
......@@ -455,7 +483,101 @@ struct Swish
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
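// Hypothetical usage of the functor above (names chosen for illustration):
//   ck::tensor_operation::element_wise::SoftRelu soft_relu{/*alpha=*/1.f};
//   float y = 0.f;
//   soft_relu(y, 2.f); // y = log(1 + exp(2 * 1)) / 1 ≈ 2.1269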
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
const float alpha_;
};
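// Hypothetical usage: Elu passes positive inputs through and decays negative ones toward -alpha:
//   ck::tensor_operation::element_wise::Elu elu{/*alpha=*/1.f};
//   float y = 0.f;
//   elu(y, -1.f); // y = 1 * expm1(-1) ≈ -0.6321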
} // namespace element_wise
......
......@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
Gemm1KPack * XdlopsGemm<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm1KPack,
A0B0B1DataType,
false>{}
.K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
......
......@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
// GEMM:
// input : A0[M, K], A1[M, K]
// input : B0[N, K], B1[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
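// Illustrative (hypothetical) instance of the contract above: with two A tensors combined
// by a_op (e.g. an elementwise add), a single B tensor with b_op = PassThrough, and one D
// tensor with cde_op(c, d0) = c + d0, the kernel effectively computes
// E = (A0 + A1) * B + D0, where D0 and E share the same [M, N] layout.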
template <typename AsDataType,
typename BsDataType,
typename ComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleABD_xdl_cshuffle
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using ComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
#else
using ComputeDataType = ComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
static constexpr auto MakeAsGridPointer()
{
return generate_tuple(
[&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
return static_cast<const ADataType*>(nullptr);
},
Number<NumATensor>{});
}
static constexpr auto MakeBsGridPointer()
{
return generate_tuple(
[&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
return static_cast<const BDataType*>(nullptr);
},
Number<NumBTensor>{});
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
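// The LDS budget is the larger of (a) the A+B staging buffers used by the main GEMM loop and
// (b) the C-shuffle buffer used by the epilogue; the two phases reuse the same allocation.
//
// Illustrative host-side sizing sketch (kernel/launcher names are hypothetical):
//   const index_t lds_bytes = GridwiseGemm::GetSharedMemoryNumberOfByte();
//   hipLaunchKernelGGL(some_kernel_entry, grid_dim, block_dim, lds_bytes, stream, /* args */);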
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
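// The M_K source descriptor is reshaped for the blockwise copy: K is unmerged into
// (AK0 = K / AK1, AK1) and the dimensions are reordered to (AK0, M, AK1), matching the LDS
// staging layout above. MakeBGridDescriptor_BK0_N_BK1 below does the same for the B tensors.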
template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{
return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
Number<NumATensor>{});
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{
return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
Number<NumBTensor>{});
}
// E desc for destination in blockwise copy
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
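// E (and each D) is viewed as (MBlock, MPerBlock, NBlock, NPerBlock): M is unmerged into
// (M / MPerBlock, MPerBlock) and N into (N / NPerBlock, NPerBlock), so a work-group can
// address its output tile directly by its (MBlock, NBlock) index.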
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
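// Default block-to-tile map: the M00_N0_M01Adapt adaptor turns a 1-D block id back into an
// (m0, n0) output-tile index. Run() only needs CalculateBottomIndex / ValidCTileIndex (and
// CheckValidity below uses the map's own CheckValidity), so a device-level op may supply a
// different map type with the same interface.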
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename AsGridDesc_M_K,
typename BsGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k,
const BsGridDesc_N_K& bs_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = as_grid_desc_m_k[I0].GetLength(I0);
const auto N = bs_grid_desc_n_k[I0].GetLength(I0);
const auto AK = as_grid_desc_m_k[I0].GetLength(I1);
const auto BK = bs_grid_desc_n_k[I0].GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
bool valid = true;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
valid =
valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB);
valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) &&
AK == as_grid_desc_m_k[i].GetLength(I1));
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
valid =
valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB);
valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) &&
BK == bs_grid_desc_n_k[i].GetLength(I1));
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check E tensor size: cannot be larger than 2GB (A/B sizes were already checked above)
if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
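// In short, CheckValidity gates the launch on: consistent (M, N, K) extents across every
// A/B/D/E descriptor, at most 2GB of element space per A/B/E tensor (presumably because some
// offsets are computed in 32-bit arithmetic), M/N/K divisibility by MPerBlock/NPerBlock/
// KPerBlock, a K-loop count supported by the selected pipeline, and a block-to-tile map that
// covers the E grid.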
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using AsGridPointer = decltype(MakeAsGridPointer());
using BsGridPointer = decltype(MakeBsGridPointer());
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
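// The raw (MRaw, KRaw) descriptor is built from the layout (row-major: strides {StrideA, 1},
// column-major: strides {1, StrideA}) and then padded up to block-size multiples according to
// GemmSpec via MatrixPadder. The B/E/D descriptor makers below follow the same pattern.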
template <typename AsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws,
const std::array<index_t, NumATensor>& KRaws,
const std::array<index_t, NumATensor>& AsStride)
{
return generate_tuple(
[&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
},
Number<NumATensor>{});
}
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& BsStride)
{
return generate_tuple(
[&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]);
},
Number<NumBTensor>{});
}
template <typename ELayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
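// Each D tensor is described with the same (M, N) view as E, so the E maker is reused per D.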
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap>
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto as_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
},
Number<NumATensor>{});
const auto bs_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
},
Number<NumBTensor>{});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
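// __builtin_amdgcn_readfirstlane broadcasts lane 0's value, which also tells the compiler the
// result is wave-uniform and can live in an SGPR (the intent stated by the HACK comment above).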
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<ComputeDataType>,
decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, false>,
Sequence<true>>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<ComputeDataType>,
decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, false>,
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// KPack: chosen as the larger of lcm(AK1, BK1) and the selected MFMA instruction's k_per_blk
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeDataType, // ComputeDataType for A
ComputeDataType, // ComputeDataType for B
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
as_grid_buf,
a_block_buf,
a_block_slice_copy_step,
bs_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
bs_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of references to the C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of references to the C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
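// Convenience overload: builds the grid descriptors on the device from raw extents and
// strides (M, N, K, StrideAs/Bs/Ds/E) and then forwards to the descriptor-based Run above.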
template <bool HasMainKBlockLoop,
GemmSpecialization GemmSpec,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap>
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const index_t M,
const index_t N,
const index_t K,
const std::array<index_t, NumATensor> StrideAs,
const std::array<index_t, NumBTensor> StrideBs,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{
using AsGridDesc_M_K =
remove_cvref_t<decltype(MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>({}, {}, {}))>;
using BsGridDesc_N_K =
remove_cvref_t<decltype(MakeBsGridDescriptor_N_K<BsLayout, GemmSpec>({}, {}, {}))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
AsGridDesc_M_K as_grid_desc_m_k;
BsGridDesc_N_K bs_grid_desc_n_k;
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumATensor, 1>{}([&](auto j) {
using ALayout = remove_cvref_t<tuple_element_t<j.value, AsLayout>>;
as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideAs[j]);
});
static_for<0, NumBTensor, 1>{}([&](auto j) {
using BLayout = remove_cvref_t<tuple_element_t<j.value, BsLayout>>;
bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(N, K, StrideBs[j]);
});
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};
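// A device-level operation is expected to drive this struct roughly as follows (illustrative
// outline, not a drop-in kernel; anything not defined in this header is hypothetical):
//   1. build the A/B/D/E grid descriptors with the Make*GridDescriptor_* helpers,
//   2. reject unsupported problems on the host via CheckValidity(...),
//   3. allocate GetSharedMemoryNumberOfByte() bytes of dynamic LDS for the kernel launch,
//   4. pick the HasMainKBlockLoop template argument with CalculateHasMainKBlockLoop(K),
//   5. have the kernel entry point forward its arguments to Run(...).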
} // namespace ck
......@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
kernel_grouped_conv_multiple_d_wmma_cshuffle(
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
......@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
......@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
......@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return true;
}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
return CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
......
......@@ -31,7 +31,7 @@ namespace ck {
// D0, D1, ... and E have the same layout
template <typename ADataType,
typename BDataType,
typename ComputeDataType_,
typename AComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
......@@ -72,7 +72,8 @@ template <typename ADataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -100,10 +101,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using ComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using ComputeDataType = ComputeDataType_;
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
......@@ -172,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeDataType),
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
......@@ -502,7 +503,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeDataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -533,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -561,13 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeDataType,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -585,10 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
......@@ -602,6 +602,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -9,13 +9,13 @@ namespace ck {
struct GridwiseGemmPipeline_v2
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
__host__ __device__ static constexpr bool IsSupported(const index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
__host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop)
{
return (num_loop / 2) > 1;
}
......
......@@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......
......@@ -108,7 +108,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
b_block_space_size_aligned * sizeof(ComputeType)),
return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle));
}
......@@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeType,
ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeType,
ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeTypeA,
ComputeTypeB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......