Unverified commit 5aa3c344 authored by rocking5566, committed by GitHub

Merge branch 'develop' into gemm_layernorm_welford

parents 7fefc966 9d8f834a
@@ -23,11 +23,10 @@ template <typename GridwiseReduction,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K>
typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_K gamma_grid_desc_k,
const GridDesc_K beta_grid_desc_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
@@ -38,8 +37,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k,
beta_grid_desc_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
@@ -71,7 +70,9 @@ template <typename XDataType,
index_t KThreadSliceSize,
index_t XYSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize>
struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
@@ -84,11 +85,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
NumReduceDim>
{
static_assert(
(KThreadSliceSize % GammaSrcVectorSize == 0),
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
(KThreadSliceSize % BetaSrcVectorSize == 0),
((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -162,38 +165,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
return (in_grid_desc_m_k_padded);
};
static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
const std::vector<index_t>& Strides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_k = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
auto grid_desc_k_padded = transform_tensor_descriptor(
grid_desc_k,
make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
@@ -203,7 +175,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
@@ -211,12 +182,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
false>;
using GridwiseReduceLayernormSweepOnce =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
@@ -225,7 +197,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
@@ -233,7 +204,9 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
@@ -258,13 +231,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
gammaStrides_(gammaStrides),
betaStrides_(betaStrides),
acc_elementwise_op_(acc_elementwise_op)
{
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
long_index_t invariant_total_length;
long_index_t reduce_total_length;
@@ -278,12 +251,17 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_;
reduceLengths_.resize(NumReduceDim);
for(int i = 0; i < NumReduceDim; ++i)
{
reduceLengths_[i] = lengths[reduceDims[i]];
}
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
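// A single sweep over the K (reduce) dimension suffices when it fits into one block tile,
// i.e. at most KThreadClusterSize * KThreadSliceSize elements: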
isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
}
AccDataType epsilon_;
@@ -295,7 +273,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_;
std::vector<index_t> reduceLengths_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
@@ -305,46 +282,35 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_;
bool isSweeponce_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
const auto gamma_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_,
arg.gammaStrides_,
arg.blkGroupSize_,
arg.numBlockTileIteration_);
const auto beta_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_,
arg.betaStrides_,
arg.blkGroupSize_,
arg.numBlockTileIteration_);
const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
bool sweep_once =
x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K>;
const auto kernel_main = arg.isSweeponce_
? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
@@ -352,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
x_grid_desc_m_k,
gamma_grid_desc_k,
beta_grid_desc_k,
y_grid_desc_m_k,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
@@ -409,26 +375,41 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
return false;
}
if(p_arg_->gammaStrides_.size() != NumReduceDim ||
p_arg_->betaStrides_.size() != NumReduceDim)
return false;
// if fastest dim is not reduced
if constexpr(GammaSrcVectorDim == 0)
{
if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
return (false);
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return (false);
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = KThreadSliceSize % scalarPerVector == 0;
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return (false);
}
return ret;
};
// if fastest dim is not reduced
if constexpr(BetaSrcVectorDim == 0)
{
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false);
if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
return false;
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return (false);
if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
return false;
if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
return (false);
}
return true;
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
struct DevicePermute : BaseOperator
{
using Lengths = std::array<index_t, NumDim>;
using Strides = Lengths;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
};
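// Illustration (hypothetical values): for an evenly strided batched GEMM with
// BatchStrideA_ = M * K, work-group g reads its A tile starting at
// p_a_grid + GetAPtrOffset(g) = p_a_grid + g * M * K, i.e. consecutive batches of A are laid
// out back to back; B, the Ds and E are offset the same way using their own batch strides.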
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batches, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
} // namespace
// Conv backward data multiple D:
// input : output image A: [G, N, K, Ho, Wo]
// input : weight B: [G, K, C, Y, X],
// input : D0, D1, ... : [G, N, C, Hi, Wi]
// output : input image E: [G, N, C, Hi, Wi]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
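// Illustration (hypothetical sizes): with G = 1, N = 4, K = 64, C = 32, a 3x3 filter and unit
// stride, A is the output-image gradient [1, 4, 64, Ho, Wo], B the weight [1, 64, 32, 3, 3],
// each D a bias/residual tensor broadcastable to [1, 4, 32, Hi, Wi], and E the computed
// input-image gradient [1, 4, 32, Hi, Wi].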
template <index_t NDimSpatial,
typename ALayout, // output image
typename BLayout, // weight
typename DsLayout, // bias
typename ELayout, // input image
typename ADataType, // output image
typename BDataType, // weight
typename AccDataType,
typename CShuffleDataType,
typename DsDataType, // bias
typename EDataType, // input image
typename AElementwiseOp, // output image
typename BElementwiseOp, // weight
typename CDEElementwiseOp, // C, bias, and input image
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
bool DoPadGemmM,
bool DoPadGemmN,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout, // output image
BLayout, // weight
DsLayout, // bias
ELayout, // input image
ADataType, // output image
BDataType, // weight
DsDataType, // bias
EDataType, // input image
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp>
{
// FIXME
static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now");
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
static constexpr index_t NumDTensor = DsDataType::Size();
// TODO make A/B datatype different
using ABDataType = ADataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto transform_conv_to_gemm =
TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
BK1,
MPerBlock,
NPerBlock,
DoPadGemmM,
DoPadGemmN>{};
static auto GetDummyABDsEGridDescriptor()
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
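// Collapse a [K0, M, K1] block-copy descriptor back into the [M, K] problem view with
// K = K0 * K1; e.g. (illustrative) [K0, M, K1] = [32, 256, 8] -> [M, K] = [256, 256].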
template <typename Desc_K0_M_K1>
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
{
const auto grid_desc_m_k = transform_tensor_descriptor(
desc_k0_m_k1,
make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
make_merge_transform(
make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_m_k;
}
// desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
using DsGridDesc_M_N = remove_cvref_t<tuple_element_t<2, ABDsEGridDesc>>;
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_k_wos_lengths[0]},
num_gemm_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
});
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
// problem definition
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
// number of GEMM
num_gemm_ = YTilde * XTilde;
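// e.g. (illustrative) a 3x3 filter with stride 2 and dilation 1 gives GcdStrideDilation = 1
// and YTilde = XTilde = 2, so the backward-data pass is decomposed into 4 GEMMs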
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
// desc for problem definition
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
}
}
}
}
void Print() const
{
for(index_t i = 0; i < num_gemm_; i++)
{
std::cout << "a_grid_desc_ak0_m_ak1_container_"
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
std::cout << "b_grid_desc_bk0_n_bk1_container_"
<< b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
static_for<0, NumDTensor, 1>{}([&](auto j) {
std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<< ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j]
<< std::endl;
});
std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<< e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i]
<< std::endl;
}
}
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptor for problem definition
index_t num_group_;
index_t num_gemm_;
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
// tensor descriptor for block-wise copy
std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
e_grid_desc_mblock_mperblock_nblock_nperblock_container_;
// block-to-e-tile map
std::vector<Block2ETileMap> block_2_etile_map_container_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOp a_element_op_;
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
float ave_time = 0;
for(index_t i = 0; i < arg.num_gemm_; i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.block_2_etile_map_container_[i]))
{
throw std::runtime_error("wrong! device_op has invalid setting");
}
const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize(
arg.e_grid_desc_m_n_container_[i]) *
arg.num_group_;
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>,
has_main_loop>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_k_wos_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_container_[i],
arg.b_grid_desc_bk0_n_bk1_container_[i],
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.block_2_etile_map_container_[i],
arg.compute_ptr_offset_of_batch_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
{
ave_time += launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time += launch_kernel(integral_constant<bool, false>{});
}
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
// Specialization
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's a 1x1 conv with stride = 1 and pad = 0
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
}
// vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// vector store for Ds
bool ds_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
{
// vector load D matrix from global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
ds_valid = false;
}
}
else
{
ds_valid = false;
}
});
if(!ds_valid)
{
return false;
}
// vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.block_2_etile_map_container_[i]))
{
return false;
}
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths,
ds_g_n_c_wis_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths,
ds_g_n_c_wis_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
<< ">";
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <utility>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Swap last 2 dimensions
// input shape:  [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
template <index_t NumDim,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t BlockSize,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
using typename BaseType::Lengths;
using typename BaseType::Strides;
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
static_assert(SrcVectorDim != DstVectorDim);
template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
{
static_assert(1 <= N && N <= NumDim);
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
}
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
const auto desc =
make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
// d[NumDim-1]]
// => [N, H, W]
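// e.g. (illustrative) NumDim = 4, lengths {2, 3, 8, 16} -> merged view [N, H, W] = [6, 8, 16]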
const index_t H = *std::next(rbegin(lengths));
const index_t W = *rbegin(lengths);
const auto desc_n_h_w = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
make_pass_through_transform(H),
make_pass_through_transform(W)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadTensorDescriptor(
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
}
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
using OutGridDesc = InGridDesc;
using GridwisePermute = GridwisePermute<
InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector,
DstScalarPerVector>;
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
struct Argument : public BaseArgument
{
Argument(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
in_lengths_(in_lengths),
in_strides_(in_strides),
out_lengths_(out_lengths),
out_strides_(out_strides),
elementwise_op_(elementwise_op),
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
{
}
const InDataType* in_dev_buffer_;
OutDataType* out_dev_buffer_;
InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_;
Lengths in_lengths_;
Strides in_strides_;
Lengths out_lengths_;
Strides out_strides_;
ElementwiseOperation elementwise_op_;
Block2TileMap block_2_tile_map_;
};
struct Invoker : BaseInvoker
{
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
const auto kernel = kernel_nd_permute<GridwisePermute,
InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
Block2TileMap>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.in_grid_desc_,
arg.out_grid_desc_,
arg.in_dev_buffer_,
arg.out_dev_buffer_,
arg.elementwise_op_,
arg.block_2_tile_map_);
return elapsed_time;
}
float Run(const BaseArgument* arg,
const StreamConfig& stream_config = StreamConfig{}) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return NAN;
}
return Run(*argument, stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
return math::integer_divide_ceil(length, tile_length) * tile_length;
};
constexpr auto IsScalarPerVectorValid =
[](index_t length, index_t stride, index_t scalar_per_vector) {
if(stride == 1 && length % scalar_per_vector == 0)
{
return true;
}
else if(stride != 1 && scalar_per_vector == 1)
{
return true;
}
return false;
};
return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.in_lengths_[SrcVectorDim],
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
arg.out_strides_[DstVectorDim],
DstScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.out_lengths_[DstVectorDim],
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.out_strides_[DstVectorDim],
DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
};
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return false;
}
return IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<Argument>(in_lengths,
in_strides,
out_lengths,
out_strides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<Invoker>();
};
// other constructor methods
template <typename... Args>
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
{
return Argument{std::forward<Args>(args)...};
}
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
{
return Invoker{};
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -218,6 +218,165 @@ struct GemmPadder_v2
KPerTileType KPerTile_;
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct MatrixPadder_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
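// e.g. (illustrative numbers) MRaw = 1000, MPerTile_ = 256 -> M = 1024, MPad = 24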
if constexpr(PadM && PadK)
{
// pad both M and K
return transform_tensor_descriptor(a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadK))
{
// pad M, but not K
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadK)
{
// pad K, but not M
return transform_tensor_descriptor(
a_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or K
return a_desc_mraw_kraw;
}
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(PadN && PadK)
{
// pad both N and K
return transform_tensor_descriptor(b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadN && (!PadK))
{
// pad N, but not K
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadN) && PadK)
{
// pad K, but not N
return transform_tensor_descriptor(
b_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad N or K
return b_desc_nraw_kraw;
}
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(PadM && PadN)
{
// pad M and N
return transform_tensor_descriptor(c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadM && (!PadN))
{
// pad M, but not N
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr((!PadM) && PadN)
{
// pad N, but not M
return transform_tensor_descriptor(
c_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_desc_mraw_nraw;
}
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout
static constexpr const char* name = "GNDHWC";
};
// for input bias
struct GC : public BaseTensorLayout
{
static constexpr const char* name = "GC";
};
// input tensor
// packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout
@@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_C";
};
// for input bias
struct G_C : public BaseTensorLayout
{
static constexpr const char* name = "G_C";
};
// weight tensor
// packed KCX/KCYX/KCZYX
struct KCX : public BaseTensorLayout
@@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout
static constexpr const char* name = "GNDHWK";
};
// for output bias
struct GK : public BaseTensorLayout
{
static constexpr const char* name = "GK";
};
// output tensor
// packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout
@@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_K";
};
// for output bias
struct G_K : public BaseTensorLayout
{
static constexpr const char* name = "G_K";
};
// K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout
{
@@ -28,6 +28,13 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<float>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -172,6 +179,14 @@ struct AddRelu
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f;
};
};
struct AddHardswish
@@ -210,6 +225,46 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
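// The implementation below folds the tanh into a logistic sigmoid using
// 0.5 * (1 + tanh(v)) = 1 / (1 + exp(-2v)): with u = 2 * sqrt(2/pi) * (x + 0.044715 * x^3),
// the cdf factor becomes 1 / (1 + exp(-u)).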
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
e = type_convert<E>(y);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
{
static_assert(is_valid_param_type_v<D>);
e = GetFastGeLU(c + type_convert<float>(d));
}
};
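// Illustrative usage sketch (not part of the library interface; the helper name below is
// hypothetical): shows the functor contract E = FastGelu(C + D) on plain scalars.
__host__ __device__ inline float add_fast_gelu_scalar_example(float c, half_t d)
{
    AddFastGelu op;
    float e = 0.f;
    op(e, c, d); // resolves to the float/float/D overload: e = FastGelu(c + float(d))
    return e;
}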
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -211,6 +211,42 @@ struct FastGelu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
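// (The constant 0.70710678118f used in the specializations below is 1/sqrt(2).)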
struct Gelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
const ck::half_t& x) const
{
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
}
};
struct Sigmoid
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = 1 / (ck::type_convert<T>(1) + exp(-x));
};
int32_t divider_ = 1;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
return is_valid;
}
// This wrapper class is for grouped GEMM: it subtracts a fixed offset from blockIdx so that
// the workgroups assigned to a given GEMM problem see top indices in the local range
// [0, grid_size_per_gemm).
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
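// Usage sketch (illustrative only; the surrounding names are assumed from a grouped GEMM
// device op): each GEMM problem wraps the underlying tile map together with the first block
// id assigned to it, so the kernel sees block indices local to that problem.
//
//   const auto base_map = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
//       c_grid_desc_m_n);
//   const auto local_map =
//       OffsettedBlockToCTileMap<decltype(base_map)>(base_map, block_start_of_this_gemm);
//   // in-kernel: local_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));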
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same datatype
typename Acc0DataType,
typename D0sDataType,
typename Acc1DataType,
typename C1ShuffleDataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementwiseOperation,
typename B0ElementwiseOperation,
typename CDE0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
InMemoryDataOperationEnum E1GlobalMemoryDataOperation,
typename A0GridDesc_M_K,
typename B0GridDesc_N_K,
typename D0sGridDesc_M_N,
typename B1GridDesc_N_K,
typename D1sGridDesc_M_N,
typename E1GridDesc_M_N,
index_t NumGemm0KPrefetchStage,
index_t BlockSize,
index_t Gemm0MPerBlock,
index_t Gemm0NPerBlock,
index_t Gemm0KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t A0K1Value,
index_t B0K1Value,
index_t B1K1Value,
index_t Gemm0MPerXdl,
index_t Gemm0NPerXdl,
index_t Gemm0MXdlPerWave,
index_t Gemm0NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
typename A0BlockTransferThreadClusterArrangeOrder,
typename A0BlockTransferSrcAccessOrder,
index_t A0BlockTransferSrcVectorDim,
index_t A0BlockTransferSrcScalarPerVector,
index_t A0BlockTransferDstScalarPerVector_AK1,
bool A0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t A0BlockLdsExtraM,
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
index_t B0BlockTransferSrcVectorDim,
index_t B0BlockTransferSrcScalarPerVector,
index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t B0BlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t C1ShuffleGemm0MXdlPerWavePerShuffle,
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto A0K1 = Number<A0K1Value>{};
static constexpr auto B0K1 = Number<B0K1Value>{};
static constexpr auto A0K0PerBlock = Number<Gemm0KPerBlock / A0K1Value>{};
static constexpr auto B0K0PerBlock = Number<Gemm0KPerBlock / B0K1Value>{};
static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave);
static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave);
// Gemm1
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto B1K0PerBlock = Number<Gemm1KPerBlock / B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0DataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static constexpr auto MakeD1sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
return static_cast<const D1DataType*>(nullptr);
},
Number<NumD1Tensor>{});
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / Gemm0NPerXdl, Gemm0NPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, MWaves, Gemm0MPerXdl>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = Gemm0NPerBlock / (Gemm0NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0NXdlPerWave, NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm0MXdlPerWave, 1, 1>(
A0BlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, Gemm0NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(A0K0PerBlock, Number<Gemm0MPerBlock>{}, A0K1),
make_tuple(Number<Gemm0MPerBlock + A0BlockLdsExtraM>{} * A0K1, A0K1, I1));
}
__host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B0 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B0K0PerBlock, Number<Gemm0NPerBlock>{}, B0K1),
make_tuple(Number<Gemm0NPerBlock + B0BlockLdsExtraN>{} * B0K1, B0K1, I1));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0PerBlock, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl>{},
I1,
Number<C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>{}));
return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned +
SharedMemTrait::b0_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(A0B0B1DataType);
const index_t c1_block_bytes_end =
SharedMemTrait::c1_block_space_size * sizeof(C1ShuffleDataType);
return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2E1TileMap>
__host__ __device__ static constexpr bool
CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k,
const B0GridDesc_N_K& b0_grid_desc_n_k,
const B1GridDesc_N_K& b1_grid_desc_n_k,
const E1GridDesc_M_N& e1_grid_desc_m_n,
const Block2E1TileMap& block_2_e1tile_map)
{
static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) &&
(Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto Gemm1N = b1_grid_desc_n_k.GetLength(I0);
if(!(M == e1_grid_desc_m_n.GetLength(I0) && Gemm1N == e1_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % Gemm0MPerBlock == 0 && N % Gemm0NPerBlock == 0 && K % Gemm0KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / Gemm0KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(Gemm0NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / Gemm0KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
// A0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K& a0_grid_desc_m_k)
{
const auto M = a0_grid_desc_m_k.GetLength(I0);
const auto K = a0_grid_desc_m_k.GetLength(I1);
const auto A0K0 = K / A0K1;
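// Worked example (assuming, say, K = 64 and A0K1 = 8): raw element (m, k) of the M_K view
// is addressed as (ak0, m, ak1) = (k / 8, m, k % 8) in the AK0_M_AK1 view.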
return transform_tensor_descriptor(
a0_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A0K0, A0K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K& b0_grid_desc_n_k)
{
const auto N = b0_grid_desc_n_k.GetLength(I0);
const auto K = b0_grid_desc_n_k.GetLength(I1);
const auto B0K0 = K / B0K1;
return transform_tensor_descriptor(
b0_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B0K0, B0K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(
M / Gemm0MPerBlock, Gemm0MXdlPerWave, Gemm0MWaves, Gemm0MPerXdl)),
make_unmerge_transform(make_tuple(N / Gemm0NPerBlock,
Gemm0NXdlPerWave,
Gemm0NWaves,
N3,
WaveSize / Gemm0NPerXdl,
N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
// B1 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// C1 desc for destination in blockwise copy
__host__ __device__ static constexpr auto
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
const auto M = e1_grid_desc_m_n.GetLength(I0);
const auto N = e1_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / Gemm0MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
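// Worked example (e.g. M = 256 with Gemm0MPerBlock = 128): MBlock = 2 and raw row m maps to
// (mblock, mperblock) = (m / 128, m % 128); the N dimension is split the same way using
// Gemm1NPerBlock.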
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e1_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<Gemm0MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
}
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumD1Tensor>{});
}
// return block_id to C1 matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2E1TileMap(const E1GridDesc_M_N& e1_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<Gemm0MPerBlock, Gemm1NPerBlock, E1GridDesc_M_N>(
e1_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A0 and B0: be careful of alignment
static constexpr auto a0_block_desc_ak0_m_ak1 =
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b0_block_desc_bk0_n_bk1 =
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1);
static constexpr auto a0_block_space_size_aligned = math::integer_least_multiple(
a0_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b0_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a0_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a0_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for C1 shuffle
static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c1_block_space_size =
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());
template <bool HasMainKBlockLoop,
typename A0GridDesc_AK0_M_AK1,
typename B0GridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename Block2E1TileMap>
__device__ static void Run(const A0B0B1DataType* __restrict__ p_a0_grid,
const A0B0B1DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const A0B0B1DataType* __restrict__ p_b1_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const A0ElementwiseOperation& a0_element_op,
const B0ElementwiseOperation& b0_element_op,
const CDE0ElementwiseOperation& cde0_element_op,
const B1ElementwiseOperation& b1_element_op,
const CDE1ElementwiseOperation& cde1_element_op,
const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap& block_2_e1tile_map)
{
const auto a0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});
// divide block work by [M, N]
const auto block_work_idx =
block_2_e1tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_e1tile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * Gemm0MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// A0 matrix in LDS memory, dst of blockwise copy
constexpr auto a0_block_desc_ak0_m_ak1 = GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B0 matrix in LDS memory, dst of blockwise copy
constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
//
// set up Gemm0
//
// A0 matrix blockwise copy
auto a0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
A0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<A0K0PerBlock, Gemm0MPerBlock, A0K1>,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(a0_grid_desc_ak0_m_ak1),
decltype(a0_block_desc_ak0_m_ak1),
A0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
A0BlockTransferSrcVectorDim,
2,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
a0_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a0_element_op,
a0_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B0 matrix blockwise copy
auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B0K0PerBlock, Gemm0NPerBlock, B0K1>,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b0_grid_desc_bk0_n_bk1),
decltype(b0_block_desc_bk0_n_bk1),
B0BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B0BlockTransferSrcVectorDim,
2,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemm0KPrefetchStage>(
b0_grid_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b0_element_op,
b0_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Fused Gemm+Gemm pipeline
// for n in N0:
//   for k in K0:
//     acc[m][n] += A[m][k] * B0[k][n]
//   acc1[m][o] += acc[m][n] * B1[n][o]
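// Scalar reference of the fused computation (illustrative comment only, per output tile):
//   for m, o:
//     acc1[m][o] = 0
//     for n:                                  // advanced NPerBlock at a time by the outer loop
//       acc[m][n] = sum over k of A0[m][k] * B0[k][n]
//       acc1[m][o] += cde0(acc[m][n], D0s[m][n]...) * B1[n][o]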
// sanity check
constexpr index_t KPack = math::max(
math::lcm(A0K1, B0K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc0DataType,
decltype(a0_block_desc_ak0_m_ak1),
decltype(b0_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a0_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b0_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
KPack,
true>{}; // TransposeC
auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
// LDS allocation for A0 and B0: be careful of alignment
auto a0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::a0_block_space_offset,
a0_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / A0K1, 0, 0);
constexpr auto b0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / B0K1, 0, 0);
const auto a0_block_reset_copy_step =
make_multi_index(-a0_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b0_block_reset_copy_step =
make_multi_index(-b0_grid_desc_bk0_n_bk1.GetLength(I0), Gemm0NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm0_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemm0KPrefetchStage, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a0_grid_desc_ak0_m_ak1.GetLength(I0) * a0_grid_desc_ak0_m_ak1.GetLength(I2)) /
Gemm0KPerBlock);
//
// set up Gemm1
//
// Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm0.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// d0 matrix threadwise copy
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
I1, // MRepeat
I1, // NRepeat
I1, // MWaveId
I1, // NWaveId
I1, // MPerXdl
I1, // NGroupNum
I1, // NInputNum
n4)); // registerNum
auto d0s_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<
AddressSpaceEnum::Vgpr,
A0B0B1DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
return ThreadwiseTensorSliceTransfer_v2<
A0B0B1DataType,
A0B0B1DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
n4,
1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
// acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: merge_v3 (division/modulo) must be used here; the plain merge transform fails to compile
constexpr auto acc0_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto Acc0N3 =
blockwise_gemm0.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
Number<Gemm1KPerBlock / n4 / Acc0N3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
Acc0DataType,
A0B0B1DataType,
decltype(acc0_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0PerBlock, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
A0B0B1DataType,
A0B0B1DataType,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
1>(b1_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, A0B0B1DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// b1 reuses gemm0's LDS space (b1_block_space_offset is 0)
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
A0B0B1DataType,
Acc1DataType,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
Gemm0MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
false, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
.K0PerXdlops>{ // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_k_block_outer_loop =
b0_grid_desc_bk0_n_bk1.GetLength(I1) / Gemm0NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock;
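// The outer loop below walks gemm0's N dimension in Gemm0NPerBlock chunks; that same dimension
// is gemm1's K dimension, so each outer iteration drives num_gemm1_k_block_inner_loop gemm1 K steps.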
// Initialize C1
c1_thread_buf.Clear();
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
// gemm0
gridwise_gemm0_pipeline.template Run<HasMainKBlockLoop>(a0_grid_desc_ak0_m_ak1,
a0_block_desc_ak0_m_ak1,
a0_blockwise_copy,
a0_grid_buf,
a0_block_buf,
a0_block_slice_copy_step,
b0_grid_desc_bk0_n_bk1,
b0_block_desc_bk0_n_bk1,
b0_blockwise_copy,
b0_grid_buf,
b0_block_buf,
b0_block_slice_copy_step,
blockwise_gemm0,
acc0_thread_buf,
num_k_block_main_loop);
// apply cde0 elementwise op to acc0 together with the D0 tensors (e.g. bias + gelu)
{
static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
static_for<0, n2, 1>{}([&](auto groupid) {
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
d0s_grid_buf[i],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, n4, 1>{}([&](auto i) {
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
make_tuple(mr, nr, groupid, i));
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
return d0s_thread_buf[iSrc][i];
},
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& {
return acc0_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
});
}
// gemm1
{
// TODO: explore using a dynamic buffer for the a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). That is not possible here, because the A1 source
// is a static buffer holding the output of the first GEMM, which requires a constexpr
// offset by design. Therefore, the tensor coordinate offset is passed explicitly in
// Run() below.
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for gemm0 LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc0_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
{
a1_blockwise_copy.Run(
acc0_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
acc0_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf);
}
} // end gemm1
a0_blockwise_copy.MoveSrcSliceWindow(a0_grid_desc_ak0_m_ak1,
a0_block_reset_copy_step); // rewind K
b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_bk0_n_bk1,
b0_block_reset_copy_step); // rewind K and step N
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end gemm1 outer K loop
// shuffle C1 and write out
{
static_assert(Gemm0MXdlPerWave % C1ShuffleGemm0MXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % C1ShuffleGemm0NXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl);
// TODO: hacky, fix it!
constexpr auto c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm1.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm1.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<C1ShuffleDataType*>(p_shared),
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0MXdlPerWavePerShuffle>{}, // M0 (Gemm0MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = Gemm0MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<C1ShuffleGemm0NXdlPerWavePerShuffle>{}, // N0 (Gemm0NXdlPerWave) per
// shuffle
N1, // N1 = NWave
N2))), // N2 = Gemm0NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM C1 matrix starting index
const auto c1_thread_mtx_on_block =
blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
C1ShuffleDataType,
decltype(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
tensor_operation::element_wise::PassThrough,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));
// shuffle: blockwise copy C from LDS to global
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(c1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDE1ElementwiseOperation,
Sequence<static_cast<index_t>(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary
// type
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave *
Gemm0NPerXdl>, // BlockSliceLengths,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde1_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<Gemm0MXdlPerWave, Gemm1NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<C1ShuffleGemm0MXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_e1_global = SpaceFillingCurve<
Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl,
1,
C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{};
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_buf,
c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde1_shuffle_block_copy_lds_to_global.Run(
c1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
c1_d1s_desc_refs, i + I1, e1_global_step);
});
// move on E1
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}
}
};
} // namespace ck
......@@ -76,7 +76,8 @@ template <typename FloatAB,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN>
bool PadN,
bool MaskOutUpperTriangle>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -97,6 +98,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
......@@ -361,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
template <bool HasMainKBlockLoop, typename Block2CTileMap>
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
......@@ -377,22 +382,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask)
{
const auto a_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
const auto b_grid_buf =
conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
NumericLimits<FloatAB>::QuietNaN()),
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -749,10 +745,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = NumericLimits<FloatGemmAcc>::Lowest();
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
// decoder (causal) lower-triangular mask: compute per-thread coordinates
const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
const index_t MPerRepeat = MPerBlock / MXdlPerWave;
const index_t NPerRepeat = NPerBlock / NXdlPerWave;
const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
if constexpr(MaskOutUpperTriangle)
{
auto gemm0_n_block_idx =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
gemm0_n_block_idx))
{
continue;
}
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
......@@ -770,16 +786,63 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf,
num_k_block_main_loop);
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
// handle N padding or upper-triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
static_for<0, m0, 1>{}([&](auto m0_i) {
const index_t m_global = mstart + m0_i * MPerRepeat;
const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
static_for<0, n0, 1>{}([&](auto n0_i) {
const index_t nstartxdl = nstart + n0_i * NPerRepeat;
const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
static_for<0, n2, 1>{}([&](auto n2_i) {
const index_t nstartgroup =
nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
static_for<0, n4, 1>{}([&](auto n4_i) {
const index_t n_global = nstartgroup + n4_i;
const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
if constexpr(MaskOutUpperTriangle)
{
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
}
else
{
acc_element_op(acc_thread_buf(acc_offset),
acc_thread_buf[acc_offset]);
}
}
else
{
// m_global is irrelevant here; only check the N bound
if(c0_matrix_mask.IsNOutOfBound(n_global))
{
acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
}
else
{
acc_element_op(acc_thread_buf(acc_offset),
acc_thread_buf[acc_offset]);
}
}
});
});
});
});
}
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
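// Note: masked and N-out-of-bound positions are set to -inf above so that exp(-inf) == 0 in
// the softmax that follows, i.e. they contribute no probability mass.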
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......@@ -881,9 +944,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // O_new
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
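// Rescaling note: c holds the output accumulated with the previous running max/sum, while
// acc1 is the contribution of the current tile computed with its local max. Scaling both by
// exp(old_or_local_max - running_max_new) and dividing by running_sum_new keeps c_new equal
// to the softmax-weighted sum over all tiles processed so far (online softmax).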
c_thread_buf(I) = c_new;
c_thread_buf(I) = c_new; // O_new
});
});
......
......@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
......@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
......
......@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
......@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
......@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
......@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
......@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
......@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2ETileMap>
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
......@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
......@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
......
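// A minimal standalone sketch (illustrative names only, not CK code) of the refactor in the
// hunk above: the grid-descriptor types move from class template parameters to per-function
// template parameters, so a single instantiation of the gridwise op can be checked and run
// against descriptors whose exact types are only known at the call site.
#include <cstdio>
template <int MPerBlock, int NPerBlock>
struct GridwiseOpSketch
{
    // descriptor type is deduced per call instead of being fixed at class scope
    template <typename EGridDesc_M_N>
    static constexpr bool CheckValidity(const EGridDesc_M_N& e_desc)
    {
        return (e_desc.M % MPerBlock == 0) && (e_desc.N % NPerBlock == 0);
    }
};
struct DescMN
{
    int M;
    int N;
};
int main()
{
    // 256 x 128 problem against a 128 x 64 block tile: both divisible, so the check passes
    printf("valid = %d\n", GridwiseOpSketch<128, 64>::CheckValidity(DescMN{256, 128}));
    return 0;
}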
......@@ -22,7 +22,6 @@ template <typename XDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -30,7 +29,9 @@ template <typename XDataType,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
......@@ -78,13 +79,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_K& gamma_grid_desc_k,
const GridDesc_K& beta_grid_desc_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
......@@ -111,11 +113,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& beta_thread_buf = gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
......@@ -127,7 +132,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
mean_square_thread_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......@@ -145,11 +150,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
......@@ -169,27 +171,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
BetaSrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
......@@ -212,9 +221,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
// Copy x from cache
// first pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k =
......@@ -224,13 +230,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// E[x], E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
// FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
index_t reducedTiles = 0;
do
......@@ -271,17 +278,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
// var(x) = E[x^2] - E[x]^2
var_value_buf(I) =
var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
});
// y = (x - E[x]) / sqrt(var[x] + epsilon)
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
reducedTiles = 0;
......@@ -296,10 +302,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
x_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_k,
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
......@@ -307,23 +313,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_value_buf(iM) + epsilon);
sqrt(var_thread_buf(iM) + epsilon);
// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
});
});
threadwise_beta_load.Run(beta_grid_desc_k,
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
......@@ -331,11 +335,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// beta
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
});
});
......@@ -346,8 +348,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
++reducedTiles;
......
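// Host-side reference sketch (plain C++ for clarity, not CK code) of what one reduced row of
// GridwiseLayernormNaiveVariance_mk_to_mk computes with the new M_K gamma/beta descriptors:
// two accumulators give E[x] and E[x^2], var(x) = E[x^2] - E[x]^2, and the second sweep applies
// y = (x - E[x]) / sqrt(var + epsilon) * gamma + beta element by element.
#include <cmath>
#include <cstdio>
#include <vector>
void layernorm_naive_row(const std::vector<float>& x,
                         const std::vector<float>& gamma,
                         const std::vector<float>& beta,
                         std::vector<float>& y,
                         float epsilon)
{
    const int K = static_cast<int>(x.size());
    float mean = 0.f, mean_square = 0.f;
    for(int k = 0; k < K; ++k)
    {
        mean += x[k];
        mean_square += x[k] * x[k];
    }
    mean /= K;
    mean_square /= K;
    const float var     = mean_square - mean * mean;
    const float inv_std = 1.f / std::sqrt(var + epsilon);
    for(int k = 0; k < K; ++k)
        y[k] = (x[k] - mean) * inv_std * gamma[k] + beta[k];
}
int main()
{
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
    std::vector<float> gamma(4, 1.f), beta(4, 0.f), y(4);
    layernorm_naive_row(x, gamma, beta, y, 1e-5f);
    printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]); // roughly -1.34 -0.45 0.45 1.34
    return 0;
}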
......@@ -19,7 +19,6 @@ template <typename XDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -27,7 +26,9 @@ template <typename XDataType,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
......@@ -70,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......@@ -77,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
{
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
// FIXME: Should not hack the transform from deviceOP
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
......@@ -94,8 +97,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
}
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_K& gamma_grid_desc_k,
const GridDesc_K& beta_grid_desc_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
......@@ -116,11 +119,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>& beta_thread_buf = gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
......@@ -137,11 +143,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
......@@ -161,27 +164,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_K,
decltype(thread_buffer_desc_k),
ThreadBufferLengths_K,
Sequence<0>,
0,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
BetaSrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
......@@ -204,9 +214,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
// Copy x from cache
// first pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k =
......@@ -216,10 +223,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
......@@ -250,11 +257,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
});
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
......@@ -268,10 +274,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_k,
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
......@@ -279,8 +285,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// normalize
y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
......@@ -288,14 +292,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
// gamma
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
});
});
threadwise_beta_load.Run(beta_grid_desc_k,
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_k,
make_tuple(I0),
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
......@@ -303,11 +307,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// beta
y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
});
});
......@@ -318,8 +320,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
......
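// Host-side sketch (not CK code) of the single-pass Welford recurrence that the ThreadwiseWelford
// accumulation above is built on: the mean and the running sum of squared deviations are updated
// together, so no E[x^2] accumulator and no separate statistics pass over x are needed.
#include <cmath>
#include <cstdio>
int main()
{
    const float x[] = {1.f, 2.f, 3.f, 4.f};
    const int K     = 4;
    float mean = 0.f, m2 = 0.f;
    int count = 0;
    for(int k = 0; k < K; ++k)
    {
        ++count;
        const float delta = x[k] - mean;
        mean += delta / count;
        m2 += delta * (x[k] - mean);
    }
    const float var = m2 / count; // population variance, same value as E[x^2] - E[x]^2
    printf("mean=%f var=%f rstd=%f\n", mean, var, 1.f / std::sqrt(var + 1e-5f));
    return 0;
}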
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <numeric>
#include <iterator>
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwisePermute,
typename InGridDesc,
typename OutGridDesc,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataType* p_in_global,
OutDataType* p_out_global,
const ElementwiseOperation elementwise_op,
const Block2TileMap block_2_tile_map)
{
__shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];
GridwisePermute::Run(in_grid_desc,
out_grid_desc,
p_in_global,
p_out_global,
p_shared,
elementwise_op,
block_2_tile_map);
}
template <typename InGridDesc,
typename OutGridDesc,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t BlockSize,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
static_assert(3 <= InGridDesc::GetNumOfDimension());
static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
SrcVectorDim < InGridDesc::GetNumOfDimension());
static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
DstVectorDim < OutGridDesc::GetNumOfDimension());
static_assert(SrcVectorDim != DstVectorDim);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
struct Block2TileMap
{
static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
static constexpr auto I0 = Number<0>{};
Block2TileMap() = delete;
Block2TileMap(const Block2TileMap&) = default;
Block2TileMap(Block2TileMap&&) = delete;
~Block2TileMap() = default;
Block2TileMap& operator=(const Block2TileMap&) = delete;
Block2TileMap& operator=(Block2TileMap&&) = delete;
explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
{
const auto N0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
const index_t grid_size = N0 * H0 * W0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == 1);
auto block_1d_id = idx_top[I0];
const auto N0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
block_1d_id = block_1d_id % (N0 * H0 * W0);
index_t idx_N0 = block_1d_id / (H0 * W0);
index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
index_t idx_W0 = block_1d_id % W0;
return make_tuple(idx_N0, idx_H0, idx_W0);
}
private:
const InGridDesc desc_;
};
using DefaultBlock2TileMap = Block2TileMap;
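// Worked example (assumed tile sizes, for orientation only): with NPerBlock = 16, HPerBlock = 32,
// WPerBlock = 32 and a 64 x 64 x 96 input, the tile grid is N0 = 4, H0 = 2, W0 = 3, and
// CalculateBottomIndex above decomposes the 1D block id row-major over (N0, H0, W0), e.g.
// block_1d_id = 17 gives
//   idx_N0 = 17 / (2 * 3) = 2
//   idx_H0 = (17 % 6) / 3 = 1
//   idx_W0 = 17 % 3       = 2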
// use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
__host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
{
return make_naive_tensor_descriptor(
make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
make_tuple(Number<HPerBlock*(WPerBlock + InBlockLdsExtraW)>{},
Number<WPerBlock + InBlockLdsExtraW>{},
I1));
}
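// Note (typical values, assumed): InBlockLdsExtraW only pads the row stride of this LDS tile,
// e.g. WPerBlock = 32 with InBlockLdsExtraW = 1 gives a row stride of 33 elements, so threads
// stepping through the H dimension land on different LDS banks instead of repeatedly hitting
// the same bank as a power-of-two stride would.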
// for an N-dimensional descriptor, keep its last 2 dimensions and merge all leading dimensions
// into a single one, forming a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
// [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
template <typename GridDesc>
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
{
constexpr index_t NumDim = GridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
const auto merged_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(generate_tuple(
[&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return merged_desc;
}
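// Example (assumed shape): a 4D descriptor with lengths [2, 3, 4, 5] is merged into a 3D one
// with lengths [(2 x 3), 4, 5] = [6, 4, 5]; only the leading dimensions are fused, the last
// two are passed through unchanged.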
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
sizeof(InDataType);
}
__host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
{
return DefaultBlock2TileMap{desc};
}
__host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
const OutGridDesc& out_grid_desc)
{
constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
// check if we only swap last 2 dimensions
bool valid = true;
static_for<0, NumDim - 2, 1>{}([&](auto I) {
if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
{
valid = false;
}
});
return valid &&
(in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
(in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
out_grid_desc.GetLength(Number<NumDim - 1>{}));
}
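// Example (assumed shapes): input lengths [8, 16, 32, 64] with output lengths [8, 16, 64, 32]
// pass this check (leading dimensions match and the last two are exchanged), while an output
// of [8, 16, 32, 64] would fail because its final two dimensions are not swapped.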
template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataType* p_in_global,
OutDataType* p_out_global,
void* __restrict__ p_shared,
const ElementwiseOperation elementwise_op,
const Block2TileMap& block_2_tile_map)
{
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
// each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
const index_t h_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
const index_t w_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
// create [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared),
in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());
using BlockSliceLengths = Sequence<NPerBlock, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t SrcVectorDimAfterMerge =
SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;
using ck::tensor_operation::element_wise::PassThrough;
// merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
// ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
PassThrough,
InMemoryDataOperationEnum::Set,
BlockSliceLengths,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
InDataType,
decltype(in_grid_desc_n_h_w),
decltype(in_block_desc_nperblock_hperblock_wperblock),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
SrcVectorDimAfterMerge,
2,
SrcScalarPerVector,
1,
1,
1,
true,
true>(in_grid_desc_n_h_w,
make_multi_index(
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
PassThrough{},
in_block_desc_nperblock_hperblock_wperblock,
make_multi_index(0, 0, 0),
PassThrough{});
// merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
// ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
// create transposed view of output tensor
const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
out_grid_desc_n_w_h,
make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
PassThrough,
InMemoryDataOperationEnum::Set,
BlockSliceLengths,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
OutDataType,
decltype(in_block_desc_nperblock_hperblock_wperblock),
decltype(out_grid_desc_n_h_w),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
2,
DstVectorDimAfterMerge,
1,
DstScalarPerVector,
1,
1,
true,
true>(in_block_desc_nperblock_hperblock_wperblock,
make_multi_index(0, 0, 0),
PassThrough{},
out_grid_desc_n_h_w,
make_multi_index(
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
in_global_load.Run(in_grid_desc_n_h_w,
in_global_buf,
in_block_desc_nperblock_hperblock_wperblock,
in_block_buf,
I0);
out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
in_block_buf,
out_grid_desc_n_h_w,
out_global_buf,
I0);
}
};
} // namespace ck
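// Host-side reference sketch (plain C++, not CK code) of the permutation GridwisePermute
// performs: after merging the leading dimensions, each block copies an [NPerBlock, HPerBlock,
// WPerBlock] tile into LDS and writes it back through a view with the last two dimensions
// swapped, which is equivalent to the loop nest below.
#include <cstdio>
#include <vector>
int main()
{
    const int N = 2, H = 3, W = 4;
    std::vector<float> in(N * H * W), out(N * W * H);
    for(int i = 0; i < N * H * W; ++i)
        in[i] = static_cast<float>(i);
    for(int n = 0; n < N; ++n)
        for(int h = 0; h < H; ++h)
            for(int w = 0; w < W; ++w)
                out[(n * W + w) * H + h] = in[(n * H + h) * W + w];
    // column h of each input row-block becomes a row of the output: prints 0 4 8 1
    printf("out[0..3] = %g %g %g %g\n", out[0], out[1], out[2], out[3]);
    return 0;
}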
......@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace tensor_operation {
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
bool DoPadGemmM,
bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename ALayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
bool>::type = false>
static auto MakeADescriptor_AK0_M_AK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t AK0 = K / AK1;
// assume packed
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
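// Worked example (assumed filter/stride values, for orientation only): with ConvStrideH = 2,
// ConvDilationH = 1 and Y = 3, GcdStrideDilationH = 1, so YTilde = 2 and YDot = ceil(3 / 2) = 2.
// The backward-data pass is split into YTilde x XTilde sub-GEMMs indexed by (i_ytilde, i_xtilde);
// for i_ytilde = 0 this GEMM covers YDotSlice = ceil((3 - 0) / 2) = 2 filter taps, while
// i_ytilde = 1 covers ceil((3 - 1) / 2) = 1.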
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
}
template <typename BLayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<BLayout, tensor_layout::convolution::GKYXC>,
bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& /* input_left_pads */,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t BK0 = K / BK1;
// assume packed
const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(BK0, GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// B: weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
}
template <typename CLayout,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
// assume strided
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
}
// for input bias
template <typename CLayout,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GC> ||
is_same_v<CLayout, tensor_layout::convolution::G_C>),
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& /* tildes */)
{
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc;
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// bias tensor
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
}
};
} // namespace tensor_operation
} // namespace ck
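// Standalone sketch (assumed example sizes, not CK code) of the HTilde slicing used in the
// descriptors above: HTilde counts every "tilde" row any filter tap can reach, and the
// [IHTildeSliceBegin, IHTildeSliceEnd) window trims the part that only touches padding.
#include <algorithm>
#include <cstdio>
int main()
{
    const int Hi = 8, Ho = 8, Y = 3;
    const int ConvStrideH = 1, ConvDilationH = 1, InLeftPadH = 1;
    // integer_divide_ceil(a, b) written out as (a + b - 1) / b
    const int HTilde = Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH;
    const int YTilde = ConvStrideH; // ConvStrideH / gcd(ConvStrideH, ConvDilationH); gcd = 1 here
    const int IHTildeSliceBegin =
        std::max(0, InLeftPadH - ConvDilationH * (YTilde - 1)) / ConvStrideH;
    const int IHTildeSliceEnd =
        std::min(HTilde, (InLeftPadH + Hi - 1 + ConvStrideH - 1) / ConvStrideH + 1);
    // prints HTilde=10 slice=[1,9) -> HTildeSlice=8, i.e. exactly Hi rows for this stride-1 case
    printf("HTilde=%d slice=[%d,%d) -> HTildeSlice=%d\n",
           HTilde, IHTildeSliceBegin, IHTildeSliceEnd, IHTildeSliceEnd - IHTildeSliceBegin);
    return 0;
}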
......@@ -16,6 +16,7 @@ namespace tensor_operation {
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
struct TransformConvFwdToGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename ALayout,
......@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm
return out_gemmm_gemmn_desc;
}
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
return out_gemmm_gemmn_desc;
}
};
} // namespace tensor_operation
......
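// Sketch (assumed shapes, not CK code) of the output-bias descriptor added above: the bias holds
// one value per K, so its GEMM view is an (NHoWo x K) matrix whose M stride is 0 and whose N
// stride is 1 -- reading element (m, k) always returns bias[k].
#include <array>
#include <cstdio>
#include <functional>
#include <numeric>
int main()
{
    // lengths laid out as [G, N, K, Ho, Wo], mirroring c_g_n_k_wos_lengths for NDimSpatial = 2
    const std::array<int, 5> c_g_n_k_wos_lengths = {1, 4, 64, 28, 28};
    const int NDimSpatial = 2;
    const int N = c_g_n_k_wos_lengths[1];
    const int NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
                                          c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                          1,
                                          std::multiplies<int>());
    printf("NHoWo = %d\n", NHoWo); // 4 * 28 * 28 = 3136
    return 0;
}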
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_IGNORE_HPP
#define CK_IGNORE_HPP
#pragma once
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
......@@ -21,4 +20,3 @@ struct ignore_t
inline constexpr detail::ignore_t ignore;
} // namespace ck
#endif
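// Minimal sketch of the std::ignore-style helper this header keeps (an ignore_t with a templated
// assignment operator); the snippet below is self-contained rather than including the header, so
// the names live in a throwaway namespace.
#include <cstdio>
namespace sketch {
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(const T&) const
    {
        return *this;
    }
};
inline constexpr ignore_t ignore;
} // namespace sketch
void kernel_stub(int unused_arg, float used_arg)
{
    sketch::ignore = unused_arg; // discards the value, silencing unused-parameter warnings
    printf("%f\n", used_arg);
}
int main()
{
    kernel_stub(0, 1.5f);
    return 0;
}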