Unverified Commit 9f8ab221 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into add_int8_wmma_example_instance

parents 755ace59 b4fc4d0b
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -5,64 +5,41 @@ ...@@ -5,64 +5,41 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/host_utility/io.hpp" #include "ck/host_utility/io.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename GridwiseImageToColumnKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_image_to_column(const InputGridDesc in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseImageToColumnKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
#else
ignore = in_grid_desc;
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = block_2_tile_map;
#endif
}
// Image to column for input layout NDHWC: // Image to column for input layout NDHWC:
// input : input image [N, Di, Hi, Wi, C], // input : input image [N, Di, Hi, Wi, C]
// output : output image [N * Do * Ho * Wo, Z * Y * X * C] // output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename InputLayout, typename ImageLayout,
typename InputDataType, typename InputDataType,
typename OutputDataType, typename OutputDataType,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t KPerBlock, index_t KPerBlock,
typename ThreadClusterLengths, typename ThreadClusterLengths,
index_t ScalarPerVector> index_t ScalarPerVector,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct DeviceImageToColumnImpl struct DeviceImageToColumnImpl
: public DeviceImageToColumn<NDimSpatial, InputLayout, InputDataType, OutputDataType> : public DeviceConvTensorRearrange<NDimSpatial,
ImageLayout,
InputDataType,
OutputDataType,
conv_tensor_rearrange_op::ImageToColumn>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -83,7 +60,7 @@ struct DeviceImageToColumnImpl ...@@ -83,7 +60,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -110,9 +87,9 @@ struct DeviceImageToColumnImpl ...@@ -110,9 +87,9 @@ struct DeviceImageToColumnImpl
c_g_n_k_wos_lengths[I1] = N; c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<InputLayout>( conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths, a_g_n_c_wis_lengths,
input_g_n_c_wis_strides, image_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor {}, // not needed for A Descriptor
c_g_n_k_wos_lengths, c_g_n_k_wos_lengths,
...@@ -132,7 +109,7 @@ struct DeviceImageToColumnImpl ...@@ -132,7 +109,7 @@ struct DeviceImageToColumnImpl
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& output_m_k_strides) const std::array<index_t, 2>& gemm_m_k_strides)
{ {
const index_t NDoHoWo = const index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<index_t>(
...@@ -141,7 +118,7 @@ struct DeviceImageToColumnImpl ...@@ -141,7 +118,7 @@ struct DeviceImageToColumnImpl
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const auto desc_mraw_kraw = make_naive_tensor_descriptor( const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1])); make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
return desc_m_k; return desc_m_k;
...@@ -155,28 +132,29 @@ struct DeviceImageToColumnImpl ...@@ -155,28 +132,29 @@ struct DeviceImageToColumnImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>( decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>; OutputGridDesc{}))>;
using GridwiseImageToColumnKernel = GridwiseImageToColumn<InputGridDesc, using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType, InputDataType,
OutputGridDesc, OutputGridDesc,
OutputDataType, OutputDataType,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
KPerBlock, KPerBlock,
ThreadClusterLengths, ThreadClusterLengths,
ScalarPerVector, ScalarPerVector,
Block2ETileMap>; InMemoryDataOperationEnum::Set,
Block2ETileMap>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const void* p_in, // input image Argument(const void* p_in, // input image
void* p_out, // output image void* p_out, // gemm form
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides, const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -185,7 +163,7 @@ struct DeviceImageToColumnImpl ...@@ -185,7 +163,7 @@ struct DeviceImageToColumnImpl
X_(filter_spatial_lengths[NDimSpatial - I1]), X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)}, p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)}, p_out_{static_cast<OutputDataType*>(p_out)},
input_g_n_c_wis_strides_{input_g_n_c_wis_strides}, image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
...@@ -197,7 +175,7 @@ struct DeviceImageToColumnImpl ...@@ -197,7 +175,7 @@ struct DeviceImageToColumnImpl
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, image_g_n_c_wis_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -205,7 +183,7 @@ struct DeviceImageToColumnImpl ...@@ -205,7 +183,7 @@ struct DeviceImageToColumnImpl
input_right_pads); input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K( out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides); N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
} }
void Print() const void Print() const
...@@ -220,7 +198,7 @@ struct DeviceImageToColumnImpl ...@@ -220,7 +198,7 @@ struct DeviceImageToColumnImpl
const InputDataType* p_in_; const InputDataType* p_in_;
OutputDataType* p_out_; OutputDataType* p_out_;
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides_; const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_strides_; const std::array<index_t, NDimSpatial>& conv_filter_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_dilations_; const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<index_t, NDimSpatial>& input_left_pads_; const std::array<index_t, NDimSpatial>& input_left_pads_;
...@@ -243,12 +221,12 @@ struct DeviceImageToColumnImpl ...@@ -243,12 +221,12 @@ struct DeviceImageToColumnImpl
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>( BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
arg.out_grid_desc_m_k_); arg.out_grid_desc_m_k_);
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_); const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
const auto kernel = kernel_image_to_column<InputGridDesc, const auto kernel = kernel_tensor_rearrange<InputGridDesc,
InputDataType, InputDataType,
OutputGridDesc, OutputGridDesc,
OutputDataType, OutputDataType,
Block2ETileMap, Block2ETileMap,
GridwiseImageToColumnKernel>; GridwiseTensorRearrangeKernel>;
float elapsed_time = launch_and_time_kernel(stream_config, float elapsed_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -273,12 +251,8 @@ struct DeviceImageToColumnImpl ...@@ -273,12 +251,8 @@ struct DeviceImageToColumnImpl
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
using namespace tensor_layout::convolution; using namespace tensor_layout::convolution;
if(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> || if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>)) std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
if(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{ {
return false; return false;
} }
...@@ -287,8 +261,8 @@ struct DeviceImageToColumnImpl ...@@ -287,8 +261,8 @@ struct DeviceImageToColumnImpl
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1; bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
// check vector acces with c not packed // check vector acces with c not packed
if(!is_c_packed && ScalarPerVector != 1) if(!is_c_packed && ScalarPerVector != 1)
...@@ -310,8 +284,8 @@ struct DeviceImageToColumnImpl ...@@ -310,8 +284,8 @@ struct DeviceImageToColumnImpl
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
return false; return false;
return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_, return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_,
arg.out_grid_desc_m_k_); arg.out_grid_desc_m_k_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -320,14 +294,14 @@ struct DeviceImageToColumnImpl ...@@ -320,14 +294,14 @@ struct DeviceImageToColumnImpl
} }
static auto MakeArgument(const void* p_in, // input image static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image void* p_out, // gemm form
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides, const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -340,8 +314,8 @@ struct DeviceImageToColumnImpl ...@@ -340,8 +314,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, image_g_n_c_wis_strides,
output_m_k_strides, gemm_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -352,14 +326,14 @@ struct DeviceImageToColumnImpl ...@@ -352,14 +326,14 @@ struct DeviceImageToColumnImpl
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image void* p_out, // gemm form
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides, const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -372,8 +346,8 @@ struct DeviceImageToColumnImpl ...@@ -372,8 +346,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, image_g_n_c_wis_strides,
output_m_k_strides, gemm_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -28,6 +28,7 @@ template <typename XDataType, ...@@ -28,6 +28,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
...@@ -43,12 +44,13 @@ template <typename XDataType, ...@@ -43,12 +44,13 @@ template <typename XDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford = true> bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x), : p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<ComputeDataType>(epsilon); epsilon_ = static_cast<ComputeDataType>(epsilon);
...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
long_index_t invariant_length; std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
long_index_t reduce_length;
std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ = gamma_grid_desc_m_k_ =
...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
beta_grid_desc_m_k_ = beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
isSweeponce_ = isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
YDataType* p_y_; YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
std::vector<index_t> Lengths_; std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_; std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GridDesc_M_K gamma_grid_desc_m_k_; GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_; GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_; GridDesc_M_K y_grid_desc_m_k_;
GridDesc_M save_mean_grid_desc_m_;
GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_; bool isSweeponce_;
index_t MRaw_; // invarient length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
UseWelford>(arg.isSweeponce_); UseWelford>(arg.isSweeponce_);
float avg_time = 0; float avg_time = 0;
...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.gamma_grid_desc_m_k_, arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_, arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_, arg.y_grid_desc_m_k_,
arg.save_mean_grid_desc_m_,
arg.save_inv_std_grid_desc_m_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.p_x_, arg.p_x_,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.p_y_, arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_); arg.y_elementwise_op_);
return (avg_time); return (avg_time);
...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0) if constexpr(XYSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
} }
else else
{ {
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false; return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false; return false;
}; };
} }
...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false); return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return (false); return (false);
} }
else // if fastest dim is reduced else // if fastest dim is reduced
...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (false); return (false);
} }
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true; return true;
}; };
...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
// TODO if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
// Optional cache of the intermediate results (mean and InvVariance) during the betaStrides.size() != Rank || yStrides.size() != Rank ||
// forward pass could speedup in the backward saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
ignore = p_saveMean; throw std::runtime_error("dimension is incorrect");
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides, yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims, reduceDims,
y_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y)); static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
template <typename GridwiseWelford, template <typename GridwiseWelford,
typename XDataType, typename XDataType,
typename MeanVarDataType, typename WorkspaceMeanVarDataType,
typename ComputeDataType, typename ComputeDataType,
typename XGridDesc_M_K, typename XGridDesc_M_K,
typename MeanVarGridDesc_M_KBlock> typename MeanVarGridDesc_M_KBlock>
...@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, ...@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
MeanVarDataType* const __restrict__ p_welford_mean, WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance, WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count) int32_t* const __restrict__ p_welford_count)
{ {
GridwiseWelford::Run(x_grid_desc_m_k, GridwiseWelford::Run(x_grid_desc_m_k,
...@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, ...@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
}; };
template <typename GridwiseWelfordNormalization, template <typename GridwiseWelfordNormalization,
typename MeanVarDataType, typename WorkspaceMeanVarDataType,
typename XDataType, typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock, typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock, typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K> typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M>
__global__ void __global__ void
kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
...@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration, index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
index_t k_grid_size, index_t k_grid_size,
ComputeDataType epsilon, ComputeDataType epsilon,
const MeanVarDataType* const p_mean_global, const WorkspaceMeanVarDataType* const p_mean_global,
const MeanVarDataType* const p_variance_global, const WorkspaceMeanVarDataType* const p_variance_global,
const int32_t* const p_welford_count_global, const int32_t* const p_welford_count_global,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
...@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
beta_grid_desc_m_k, beta_grid_desc_m_k,
y_grid_desc_m_k, y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_mean_var_count_iteration, num_k_mean_var_count_iteration,
num_k_block_tile_iteration, num_k_block_tile_iteration,
k_grid_size, k_grid_size,
...@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
p_gamma_global, p_gamma_global,
p_beta_global, p_beta_global,
p_y_global, p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op); y_elementwise_op);
}; };
} // namespace ck } // namespace ck
...@@ -107,6 +117,7 @@ template <typename XDataType, ...@@ -107,6 +117,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
...@@ -121,17 +132,18 @@ template <typename XDataType, ...@@ -121,17 +132,18 @@ template <typename XDataType,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
using MeanVarDataType = ComputeDataType; using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert( static_assert(
...@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int kBlockSize, int kBlockSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}; };
template <typename DoPads, index_t MPerTile, index_t KPerTile> template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K) static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
{ {
const auto grid_desc_m_k = const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
...@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
} }
template <typename DoPads, index_t MPerTile, index_t KPerTile> template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeCountDescriptor_M_K(index_t M, index_t K) static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
{ {
const auto grid_desc_m_k = const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
} }
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using Kernel1MeanVarGridDesc_M_KBlock = using Kernel1MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1)); decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
using Kernel2MeanVarGridDesc_M_KBlock = using Kernel2MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1)); decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using Kernel2CountGridDesc_M_KBlock = using Kernel2CountGridDesc_M_KBlock =
decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1)); decltype(MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType, using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType,
ComputeDataType, ComputeDataType,
MeanVarDataType, WorkspaceMeanVarDataType,
SrcGridDesc_M_K, SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock, Kernel1MeanVarGridDesc_M_KBlock,
BlockSize, BlockSize,
...@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
XSrcVectorSize>; XSrcVectorSize>;
using GridwiseWelfordNormalization = using GridwiseWelfordNormalization =
GridwiseNormalizationSplitK2nd<MeanVarDataType, GridwiseNormalizationSplitK2nd<WorkspaceMeanVarDataType,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock, Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K, SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYVectorDim, XYVectorDim,
YDstVectorSize>; YDstVectorSize,
SaveMeanInvStdDstVectorSize>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x), : p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
p_workspace_mean_{nullptr}, p_workspace_mean_{nullptr},
p_workspace_var_{nullptr}, p_workspace_var_{nullptr},
p_workspace_count_{nullptr}, p_workspace_count_{nullptr},
...@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_); std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
...@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_ = y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1. // We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_ = kernel1_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(MRaw_, MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(
kGridSize_); MRaw_, kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ = kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>, MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize, M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ = kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>, MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize, M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
...@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
YDataType* p_y_; YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
void* p_workspace_mean_; void* p_workspace_mean_;
void* p_workspace_var_; void* p_workspace_var_;
void* p_workspace_count_; void* p_workspace_count_;
...@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
...@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K gamma_grid_desc_m_k_; SrcGridDesc_M_K gamma_grid_desc_m_k_;
SrcGridDesc_M_K beta_grid_desc_m_k_; SrcGridDesc_M_K beta_grid_desc_m_k_;
SrcGridDesc_M_K y_grid_desc_m_k_; SrcGridDesc_M_K y_grid_desc_m_k_;
SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_;
SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_;
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
...@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t MRaw_; // invarient length index_t MRaw_; // invarient length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford, auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford,
XDataType, XDataType,
MeanVarDataType, WorkspaceMeanVarDataType,
ComputeDataType, ComputeDataType,
SrcGridDesc_M_K, SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock>; Kernel1MeanVarGridDesc_M_KBlock>;
auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization, auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization,
MeanVarDataType, WorkspaceMeanVarDataType,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock, Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K>; SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(
kernel1, stream_config,
dim3(arg.gridSize_), kernel1,
dim3(BlockSize), dim3(arg.gridSize_),
0, dim3(BlockSize),
arg.x_grid_desc_m_k_, 0,
arg.kernel1_mean_var_grid_desc_m_kblock_, arg.x_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.kernel1_mean_var_grid_desc_m_kblock_,
arg.p_x_, arg.numBlockTileIteration_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_), arg.p_x_,
static_cast<MeanVarDataType*>(arg.p_workspace_var_), static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<int32_t*>(arg.p_workspace_count_)); static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_));
avg_time += launch_and_time_kernel(stream_config,
kernel2, avg_time += launch_and_time_kernel(
dim3(arg.gridSize_), stream_config,
dim3(BlockSize), kernel2,
0, dim3(arg.gridSize_),
arg.kernel2_mean_var_grid_desc_m_kblock_, dim3(BlockSize),
arg.kernel2_count_grid_desc_m_kblock_, 0,
arg.x_grid_desc_m_k_, arg.kernel2_mean_var_grid_desc_m_kblock_,
arg.gamma_grid_desc_m_k_, arg.kernel2_count_grid_desc_m_kblock_,
arg.beta_grid_desc_m_k_, arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_, arg.gamma_grid_desc_m_k_,
arg.numMeanVarCountIteration_, arg.beta_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.y_grid_desc_m_k_,
arg.kGridSize_, arg.save_mean_grid_desc_m_,
arg.epsilon_, arg.save_inv_std_grid_desc_m_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_), arg.numMeanVarCountIteration_,
static_cast<MeanVarDataType*>(arg.p_workspace_var_), arg.numBlockTileIteration_,
static_cast<int32_t*>(arg.p_workspace_count_), arg.kGridSize_,
arg.p_x_, arg.epsilon_,
arg.p_gamma_, static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
arg.p_beta_, static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
arg.p_y_, static_cast<const int32_t*>(arg.p_workspace_count_),
arg.y_elementwise_op_); arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_);
return avg_time; return avg_time;
}; };
...@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
// workspace for welford intermediate mean // workspace for welford intermediate mean
workspace_size += welford_size * sizeof(MeanVarDataType) + 64; workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate variance // workspace for welford intermediate variance
workspace_size += welford_size * sizeof(MeanVarDataType) + 64; workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate count // workspace for welford intermediate count
workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
...@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_); pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = welford_size * sizeof(MeanVarDataType); index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance // setup buffer used for intermediate welford varirance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz; pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = welford_size * sizeof(MeanVarDataType); index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count // setup buffer used for intermediate welford count
...@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYVectorDim == 0) if constexpr(XYVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false; return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false; return false;
}; };
} }
...@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return false; return false;
} }
else // if fastest dim is reduced else // if fastest dim is reduced
...@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->kGridSize_ <= 1) if(p_arg_->kGridSize_ <= 1)
return false; return false;
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true; return true;
}; };
...@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
// TODO if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
// Optional cache of the intermediate results (mean and InvVariance) during the betaStrides.size() != Rank || yStrides.size() != Rank ||
// forward pass could speedup in the backward saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
ignore = p_saveMean; throw std::runtime_error("dimension is incorrect");
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides, yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims, reduceDims,
y_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y)); static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -113,7 +113,6 @@ struct PassThrough ...@@ -113,7 +113,6 @@ struct PassThrough
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{ {
...@@ -143,7 +142,36 @@ struct PassThrough ...@@ -143,7 +142,36 @@ struct PassThrough
{ {
y = type_convert<f8_t>(x); y = type_convert<f8_t>(x);
} }
#endif
template <>
__host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, bf8_t>(float& y, const bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, float>(bf8_t& y, const float& x) const
{
y = type_convert<bf8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, bf8_t>(half_t& y, const bf8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{
y = ck::type_convert<bf8_t>(x);
}
}; };
struct UnaryConvert struct UnaryConvert
...@@ -172,7 +200,6 @@ struct ConvertBF16RTN ...@@ -172,7 +200,6 @@ struct ConvertBF16RTN
} }
}; };
#if defined CK_ENABLE_FP8
struct ConvertF8SR struct ConvertF8SR
{ {
// convert to fp8 using stochastic rounding (SR) // convert to fp8 using stochastic rounding (SR)
...@@ -189,7 +216,6 @@ struct ConvertF8SR ...@@ -189,7 +216,6 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
} }
}; };
#endif
struct Scale struct Scale
{ {
...@@ -416,10 +442,11 @@ struct Sigmoid ...@@ -416,10 +442,11 @@ struct Sigmoid
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value, is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = 1 / (ck::type_convert<T>(1) + exp(-x)); y = one / (one + ck::math::exp(-x));
}; };
}; };
...@@ -429,7 +456,8 @@ struct TanH ...@@ -429,7 +456,8 @@ struct TanH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value, is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::tanh(x); y = ck::math::tanh(x);
...@@ -455,7 +483,101 @@ struct Swish ...@@ -455,7 +483,101 @@ struct Swish
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; };
float beta_ = 1.0f; const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
const float alpha_;
}; };
} // namespace element_wise } // namespace element_wise
......
...@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataType,
ABDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
false, // TransposeC false, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
false, // TransposeC false, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{} Gemm1KPack * XdlopsGemm<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm1KPack,
A0B0B1DataType,
false>{}
.K0PerXdlops>{ // BMmaKStride .K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
true, // TransposeC true, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
true, // TransposeC true, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack *
XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, FloatAB, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
......
...@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
// GEMM:
// input : A0[M, K], A1[M, K]
// input : B0[N, K], B1[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename AsDataType,
typename BsDataType,
typename ComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleABD_xdl_cshuffle
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using ComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
#else
using ComputeDataType = ComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
static constexpr auto MakeAsGridPointer()
{
return generate_tuple(
[&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
return static_cast<const ADataType*>(nullptr);
},
Number<NumATensor>{});
}
static constexpr auto MakeBsGridPointer()
{
return generate_tuple(
[&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
return static_cast<const BDataType*>(nullptr);
},
Number<NumBTensor>{});
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{
return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
Number<NumATensor>{});
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{
return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
Number<NumBTensor>{});
}
// E desc for destination in blockwise copy
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AsGridDesc_M_K,
typename BsGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k,
const BsGridDesc_N_K& bs_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = as_grid_desc_m_k[I0].GetLength(I0);
const auto N = bs_grid_desc_n_k[I0].GetLength(I0);
const auto AK = as_grid_desc_m_k[I0].GetLength(I1);
const auto BK = bs_grid_desc_n_k[I0].GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
bool valid = true;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
valid =
valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB);
valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) &&
AK == as_grid_desc_m_k[i].GetLength(I1));
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
valid =
valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB);
valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) &&
BK == bs_grid_desc_n_k[i].GetLength(I1));
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using AsGridPointer = decltype(MakeAsGridPointer());
using BsGridPointer = decltype(MakeBsGridPointer());
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
template <typename AsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws,
const std::array<index_t, NumATensor>& KRaws,
const std::array<index_t, NumATensor>& AsStride)
{
return generate_tuple(
[&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
},
Number<NumATensor>{});
}
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& BsStride)
{
return generate_tuple(
[&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]);
},
Number<NumBTensor>{});
}
template <typename ELayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap>
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto as_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
},
Number<NumATensor>{});
const auto bs_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
},
Number<NumBTensor>{});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<ComputeDataType>,
decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, false>,
Sequence<true>>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<ComputeDataType>,
decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, false>,
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeDataType, // ComputeDataType for A
ComputeDataType, // ComputeDataType for B
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
as_grid_buf,
a_block_buf,
a_block_slice_copy_step,
bs_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
bs_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
template <bool HasMainKBlockLoop,
GemmSpecialization GemmSpec,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap>
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const index_t M,
const index_t N,
const index_t K,
const std::array<index_t, NumATensor> StrideAs,
const std::array<index_t, NumBTensor> StrideBs,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{
using AsGridDesc_M_K =
remove_cvref_t<decltype(MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>({}, {}, {}))>;
using BsGridDesc_N_K =
remove_cvref_t<decltype(MakeBsGridDescriptor_N_K<BsLayout, GemmSpec>({}, {}, {}))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
AsGridDesc_M_K as_grid_desc_m_k;
BsGridDesc_N_K bs_grid_desc_n_k;
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumATensor, 1>{}([&](auto j) {
using ALayout = remove_cvref_t<tuple_element_t<j.value, AsLayout>>;
as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideAs[j]);
});
static_for<0, NumBTensor, 1>{}([&](auto j) {
using BLayout = remove_cvref_t<tuple_element_t<j.value, BsLayout>>;
bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(N, K, StrideBs[j]);
});
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};
} // namespace ck
...@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -36,7 +36,7 @@ __global__ void ...@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( kernel_grouped_conv_multiple_d_wmma_cshuffle(
const ADataType* __restrict__ p_a_grid, const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
...@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
...@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2))) K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
...@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return true; return true;
} }
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
return CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / (K0PerBlock * K1); const index_t num_loop = K / (K0PerBlock * K1);
......
...@@ -31,7 +31,7 @@ namespace ck { ...@@ -31,7 +31,7 @@ namespace ck {
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename ComputeDataType_, typename AComputeDataType_,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
...@@ -72,7 +72,8 @@ template <typename ADataType, ...@@ -72,7 +72,8 @@ template <typename ADataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_xdl_cshuffle struct GridwiseGemmMultipleD_xdl_cshuffle
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -100,10 +101,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -100,10 +101,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX #if CK_WORKAROUND_DENORM_FIX
using ComputeDataType = using AComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>; conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else #else
using ComputeDataType = ComputeDataType_; using AComputeDataType = AComputeDataType_;
#endif #endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
...@@ -172,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -172,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
sizeof(ComputeDataType), b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType)); c_block_size * sizeof(CShuffleDataType));
} }
...@@ -502,7 +503,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -502,7 +503,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ADataType, ADataType,
ComputeDataType, AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -533,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -533,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BDataType, BDataType,
ComputeDataType, BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -561,13 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -561,13 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack = math::max(
math::max(math::lcm(AK1, BK1), math::lcm(AK1, BK1),
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeDataType, AComputeDataType,
BComputeDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -585,10 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -585,10 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned, static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
...@@ -602,6 +602,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -602,6 +602,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeType, ComputeType,
ComputeType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -9,13 +9,13 @@ namespace ck { ...@@ -9,13 +9,13 @@ namespace ck {
struct GridwiseGemmPipeline_v2 struct GridwiseGemmPipeline_v2
{ {
__host__ __device__ static constexpr bool IsSupported(index_t num_loop) __host__ __device__ static constexpr bool IsSupported(const index_t num_loop)
{ {
// TODO: improve applicability // TODO: improve applicability
return num_loop % 2 == 0; return num_loop % 2 == 0;
} }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop)
{ {
return (num_loop / 2) > 1; return (num_loop / 2) > 1;
} }
......
...@@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle ...@@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataType,
ABDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle ...@@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataType, ABDataType,
ABDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -108,7 +108,8 @@ template <typename ALayout, ...@@ -108,7 +108,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1, PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC> typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ComputeType) + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeType)), b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle)); c_block_size * sizeof(FloatCShuffle));
} }
...@@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatA, FloatA,
ComputeType, ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatB, FloatB,
ComputeType, ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack = math::max(
math::max(math::lcm(AK1Number, BK1Number), math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeType, ComputeTypeA,
ComputeTypeB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned, static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment