Commit 463e2aa1 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 6e106c19 236bd148
......@@ -15,7 +15,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm
......@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
......@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
{
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
......@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
......
......@@ -163,6 +163,13 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
#else // __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
#endif // __gfx908__
namespace ck {
enum struct InMemoryDataOperationEnum
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -13,7 +13,15 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -13,13 +13,22 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
......@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const void* estimatedMean,
const void* estimatedInvVariance,
void* p_y) = 0;
......@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyDxVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0 &&
MThreadSliceSize % DxDstVectorSize == 0) ||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0 &&
KThreadSliceSize % DxDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
const std::array<index_t, Rank>& xyStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleXYLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
const auto tupleXYStrides =
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
Number<NumBatchNormReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(raw_grid_desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_g_padded =
transform_tensor_descriptor(grid_desc_m_g,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_g_padded);
};
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad =
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
const std::array<index_t, NumInvariantDim>& strides)
{
const auto tupleLengths =
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
const auto tupleStrides =
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_m = transform_tensor_descriptor(
raw_grid_desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded =
transform_tensor_descriptor(grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, mPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_m_padded);
};
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
using MeanVarGridDesc_M = ScaleBiasGridDesc_M;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
double epsilon,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
xStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
dyStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dyStrides, reduceDims);
dxStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dxStrides, reduceDims);
std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr);
if(UseMultiblockInK)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
iterations++;
};
blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration = iterations;
}
else
{
blkGroupSize = 1;
numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize = (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize;
x_grid_desc_m_k =
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize, numBlockTileIteration);
dy_grid_desc_m_k =
MakeXY2dDescriptor(xyLengths_, dyStrides_, blkGroupSize, numBlockTileIteration);
dx_grid_desc_m_k =
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
scale_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
bias_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
mean_var_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
}
AccDataType epsilon_;
bool haveSavedMeanInvVar_;
std::array<index_t, Rank> xyLengths_;
std::array<index_t, Rank> xStrides_;
std::array<index_t, Rank> dyStrides_;
std::array<index_t, Rank> dxStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
long_index_t invariant_length;
long_index_t reduce_length;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
XYGridDesc_M_K x_grid_desc_m_k;
XYGridDesc_M_K dy_grid_desc_m_k;
XYGridDesc_M_K dx_grid_desc_m_k;
ScaleBiasGridDesc_M scale_grid_desc_m;
ScaleBiasGridDesc_M bias_grid_desc_m;
MeanVarGridDesc_M mean_var_grid_desc_m;
void* workspace_mean;
void* workspace_variance;
void* workspace_count;
void* workspace_savedMean;
void* workspace_savedInvVar;
void* workspace_reduce_dscale;
void* workspace_reduce_dbias;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
{
// workspace for the partial reduced result for dscale
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
// workspace for the partial reduced result for dbias
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
if(!pArg_->haveSavedMeanInvVar_)
{
// workspace for welford intermediate mean
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64;
// workspace for welford result mean
workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
// workspace for welford result inv_variance
workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
};
}
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
index_t space_sz;
// setup buffer for the partial reduced result for dscale
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for the partial reduced result for dbias
pArg_->workspace_reduce_dbias =
reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
{
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford intermediate mean
pArg_->workspace_mean =
reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford intermediate varirance
pArg_->workspace_variance = reinterpret_cast<char*>(pArg_->workspace_mean) + space_sz;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford intermediate count
pArg_->workspace_count = reinterpret_cast<char*>(pArg_->workspace_variance) + space_sz;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford result mean
pArg_->workspace_savedMean = reinterpret_cast<char*>(pArg_->workspace_count) + space_sz;
space_sz = pArg_->invariant_length * sizeof(MeanVarDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford result inv_variance
pArg_->workspace_savedInvVar =
reinterpret_cast<char*>(pArg_->workspace_savedMean) + space_sz;
};
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
const auto mean_var_count_grid_desc_m_g =
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
const auto dscale_dbias_grid_desc_m_g =
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
const auto mean_var_count_grid_desc_m_k =
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
const auto dscale_dbias_grid_desc_m_k =
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g);
using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k);
using GridwiseWelfordSecondHalfReduceFirstHalf_ =
GridwiseWelfordSecondHalfReduceFirstHalf<XDataType,
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
DscaleDbiasGridDesc_M_G,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XDyDxVectorDim,
XSrcVectorSize,
DySrcVectorSize,
MeanVarSrcVectorSize>;
using GridwiseReduceSecondHalfBatchNormBwdFinal_ =
GridwiseReduceSecondHalfBatchNormBackwardFinal<XDataType,
DyDataType,
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
ScaleBiasGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XDyDxVectorDim,
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
MeanVarSrcVectorSize>;
if(UseMultiblockInK && arg.blkGroupSize > 1)
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.blkGroupSize, arg.numBlockTileIteration, arg.reduce_length);
if(!arg.haveSavedMeanInvVar_)
{
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XDyDxVectorDim,
XSrcVectorSize>;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(
stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean),
static_cast<MeanVarDataType*>(arg.workspace_variance),
static_cast<int32_t*>(arg.workspace_count));
};
const auto kern_welford_second_half_reduce_first_half =
kernel_welford_second_half_reduce_first_half<
GridwiseWelfordSecondHalfReduceFirstHalf_,
XDataType,
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
DscaleDbiasGridDesc_M_G>;
const auto kern_reduce_second_half_batchnorm_backward_final =
kernel_reduce_second_half_batchnorm_backward_final<
GridwiseReduceSecondHalfBatchNormBwdFinal_,
XDataType,
DyDataType,
DxDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
ScaleBiasGridDesc_M>;
index_t numDscaleDbiasBlockTileIteration =
(arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;
avg_time += launch_and_time_kernel(
stream_config,
kern_welford_second_half_reduce_first_half,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k,
arg.dy_grid_desc_m_k,
arg.mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
dscale_dbias_grid_desc_m_g,
arg.blkGroupSize,
arg.numBlockTileIteration,
numDscaleDbiasBlockTileIteration,
arg.epsilon_,
arg.haveSavedMeanInvVar_,
arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr,
arg.haveSavedMeanInvVar_
? nullptr
: static_cast<const MeanVarDataType*>(arg.workspace_mean),
arg.haveSavedMeanInvVar_
? nullptr
: static_cast<const MeanVarDataType*>(arg.workspace_variance),
arg.haveSavedMeanInvVar_ ? nullptr
: static_cast<const int32_t*>(arg.workspace_count),
arg.dy_elementwise_op_,
arg.haveSavedMeanInvVar_
? nullptr
: static_cast<MeanVarDataType*>(arg.workspace_savedMean),
arg.haveSavedMeanInvVar_
? nullptr
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_,
arg.p_dy_,
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
avg_time += launch_and_time_kernel(
stream_config,
kern_reduce_second_half_batchnorm_backward_final,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k,
arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k,
dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.blkGroupSize,
arg.reduce_length,
arg.numBlockTileIteration,
numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_
? arg.p_savedMean_
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
arg.haveSavedMeanInvVar_
? arg.p_savedInvVar_
: static_cast<const MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_,
arg.p_dy_,
arg.p_scale_,
arg.dy_elementwise_op_,
arg.p_dx_,
arg.p_dscale_,
arg.p_dbias_);
}
else
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.numBlockTileIteration, arg.reduce_length);
using GridwiseBatchNormBackwardWithBlockwiseWelford_ =
GridwiseBatchNormBackwardWithBlockwiseWelford<XDataType,
DyDataType,
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
ScaleBiasGridDesc_M,
MeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XDyDxVectorDim,
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
MeanVarSrcVectorSize>;
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
GridwiseBatchNormBackwardWithBlockwiseWelford_,
XDataType,
DyDataType,
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
ScaleBiasGridDesc_M,
MeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(stream_config,
kern_batchnorm_bwd,
dim3(arg.gridSize),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k,
arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.mean_var_grid_desc_m,
get_reduce_count_per_thread,
arg.reduce_length,
arg.numBlockTileIteration,
arg.epsilon_,
arg.p_x_,
arg.p_dy_,
arg.p_scale_,
arg.haveSavedMeanInvVar_,
arg.p_savedMean_,
arg.p_savedInvVar_,
arg.dy_elementwise_op_,
arg.p_dx_,
arg.p_dscale_,
arg.p_dbias_);
};
return (avg_time);
};
float Run(const BaseArgument* pArg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* pArg) override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
if constexpr(XDyDxVectorDim == 0)
{
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
pArg_->dyStrides_[NumInvariantDim - 1] != 1 ||
pArg_->dxStrides_[NumInvariantDim - 1] != 1)
return false;
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0)
return false;
}
else
{
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 ||
pArg_->dxStrides_[Rank - 1] != 1)
return false;
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0)
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
return false;
if(pArg_->haveSavedMeanInvVar_)
{
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0)
return false;
};
bool is_valid = true;
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
is_valid = false;
});
if(!is_valid)
return false;
return true;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dyStrides,
dxStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
dy_elementwise_op,
epsilon,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchNormBwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -42,8 +42,15 @@ template <typename XDataType,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdBwdDataNwcKxcNwk_Dl
: public DeviceConvBwdData<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Dl;
using ADataType = OutDataType;
using BDataType = WeiDataType;
using CDataType = InDataType;
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
index_t i_xtilde = tildes[0];
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const auto K0 = K / K1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
make_tuple(make_pass_through_transform(N * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
const auto out_n_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K));
const auto wei_k_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X, C));
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
out_n_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
out_n_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const auto K0 = K / K1;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
const index_t i_ztilde = tildes[0];
const index_t i_ytilde = tildes[1];
const index_t i_xtilde = tildes[2];
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const auto K0 = K / K1;
const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
const auto wei_k_z_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<0, 2, 4, 6>{},
Sequence<7>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_k_z_y_x_c_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2>{},
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<5>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
1,
1,
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{0, 0, 0});
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
output_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
CreateABCDesc<NDimSpatial>();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideW = conv_filter_strides_[0];
const index_t ConvDilationW = conv_filter_dilations_[0];
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t X = filter_spatial_lengths_[0];
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Y = filter_spatial_lengths_[0];
const index_t X = filter_spatial_lengths_[1];
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideD = conv_filter_strides_[0];
const index_t ConvStrideH = conv_filter_strides_[1];
const index_t ConvStrideW = conv_filter_strides_[2];
const index_t ConvDilationD = conv_filter_dilations_[0];
const index_t ConvDilationH = conv_filter_dilations_[1];
const index_t ConvDilationW = conv_filter_dilations_[2];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Z = filter_spatial_lengths_[0];
const index_t Y = filter_spatial_lengths_[1];
const index_t X = filter_spatial_lengths_[2];
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<AGridDesc_K0_M0_M1_K1> a_grid_desc_k0_m0_m1_k1_container_;
std::vector<BGridDesc_K0_N0_N1_K1> b_grid_desc_k0_n0_n1_k1_container_;
std::vector<CGridDesc_M0_M10_M11_N0_N10_N11> c_grid_desc_m0_m10_m11_n0_n10_n11_container_;
std::vector<DefaultBlock2CTileMap> block_2_ctile_map_container_;
// element-wise op
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> output_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_gemm_dl_v1r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
has_main_loop,
has_double_loop>;
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_container_[i],
arg.b_grid_desc_k0_n0_n1_k1_container_[i],
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i],
arg.block_2_ctile_map_container_[i]);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
{
return false;
}
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
}
// matrix A
{
auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t K = arg.Conv_K_;
if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
// matrix B
{
auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
{
return false;
}
if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
{
return false;
}
const index_t C = arg.Conv_K_;
if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
{
return false;
}
}
// vector store C matrix into global memory
if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
{
std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
<< arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
return false;
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
{
return false;
}
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
const void* p_wei_grid,
const void* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConvNdBwdDataNwcKxcNwk_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
str<< " Filter1x1Stride1Pad0";
}
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -194,21 +194,36 @@ struct Relu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// Y = FastGelu(X)
struct FastGelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
y = x * cdf;
const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename XDataType,
typename DyDataType,
typename DxDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
dscale_dbias_grid_desc_m_k,
mean_var_grid_desc_m,
scale_grid_desc_m,
bias_grid_desc_m,
blkgroup_size,
reduce_size,
num_xy_k_block_tile_iteration,
num_dscale_dbias_k_block_tile_iteration,
p_reduce_dscale,
p_reduce_dbias,
p_mean,
p_inv_var,
p_x,
p_dy,
p_scale,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
};
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyDxVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0 &&
MThreadSliceSize % DxDstVectorSize == 0) ||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0 &&
KThreadSliceSize % DxDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_1,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dx_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
inv_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// clang-format off
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dscale_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
threadwise_dscale_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
// clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
// clang-format on
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DySrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DxDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DxDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dx_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_mean_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
AccDataType inv_reduce_size =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
dbias_thread_buf[iM] - tmpVal);
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
}
};
};
} // namespace ck
......@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
......
......@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
dscale_dbias_grid_desc_m_g,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
dy_elementwise_op,
p_out_welford_mean,
p_out_welford_inv_variance,
p_x,
p_dy,
p_reduce_dscale,
p_reduce_dbias);
};
template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0) ||
(XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance)
// Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& mean_thread_buf =
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
// buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// clang-format off
// Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<AccDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(welford_mean_thread_buf(I),
welford_var_thread_buf(I),
welford_count_thread_buf(I));
});
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
});
if(block_local_id == 0 && thread_k_cluster_id == 0)
{
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf,
mean_var_grid_desc_m,
inv_var_global_buf);
};
};
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
DySrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
// clang-format off
// Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
});
});
ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
};
};
};
} // namespace ck
......@@ -874,6 +874,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
} // end gemm1
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
{
__builtin_amdgcn_sched_barrier(0);
}
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
reduce_size,
num_k_block_tile_iteration,
epsilon,
p_x,
p_dy,
p_scale,
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
};
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyDxVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0 &&
MThreadSliceSize % DxDstVectorSize == 0) ||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0 &&
KThreadSliceSize % DxDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Blockwise BatchNorm Backward
// Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
// Output: dx, dscale, dbias
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
// Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
{
using ck::math::sqrt;
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dx_thread_buf;
// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DxDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DxDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
inv_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
});
threadwise_x_load.SetSrcSliceOrigin(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
};
// clang-format off
// Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
// clang-format on
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0);
dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dx_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
});
});
ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
};
// clang-format off
// Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
AccDataType inv_reduce_size =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
dbias_thread_buf[iM] - tmpVal);
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
};
} // namespace ck
......@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
......@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
}
};
......
......@@ -4,50 +4,61 @@
#pragma once
#include <iostream>
#include <vector>
#include <array>
#include <sstream>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename YDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
typename MeanVarDataType,
typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const MeanVarDataType* estimatedMean,
const MeanVarDataType* estimatedVariance,
YDataType* p_y)
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
: p_x_(p_x),
bnScale_(bnScale),
bnBias_(bnBias),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon),
estimatedMean_(estimatedMean),
estimatedVariance_(estimatedVariance),
p_y_(p_y)
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
ignore = xStrides;
ignore = yStrides;
ignore = dyStrides;
ignore = dxStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
......@@ -56,22 +67,31 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
throw std::runtime_error("Invalid reduce dimensions!");
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
c_ = xyLengths[3];
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
const XDataType* p_x_;
const ScaleDataType* bnScale_;
const BiasDataType* bnBias_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
double epsilon_;
const DyElementwiseOp dy_elementwise_op_;
const MeanVarDataType* estimatedMean_;
const MeanVarDataType* estimatedVariance_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
YDataType* p_y_;
bool haveSavedMeanInvVar_;
index_t n_, h_, w_, c_;
};
......@@ -81,15 +101,60 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
index_t offset_C = iC;
AccDataType mean = arg.estimatedMean_[offset_C];
AccDataType variance = arg.estimatedVariance_[offset_C];
AccDataType reduceSize = type_convert<AccDataType>(arg.n_) *
type_convert<AccDataType>(arg.h_) *
type_convert<AccDataType>(arg.w_);
index_t offset_C = iC;
AccDataType mean;
AccDataType invVar;
if(arg.haveSavedMeanInvVar_)
{
mean = arg.p_savedMean_[offset_C];
invVar = arg.p_savedInvVar_[offset_C];
}
else
{
AccDataType meansquare;
meansquare = type_convert<AccDataType>(0.0f);
mean = type_convert<AccDataType>(0.0f);
// compute mean, meanquare, variance, inv-variance
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType invVariance =
type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
mean += x;
meansquare += x * x;
};
}
};
// Normalization
mean = mean / reduceSize;
meansquare = meansquare / reduceSize;
AccDataType variance = meansquare - mean * mean;
invVar = type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
};
AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on NHW dimensions
// 3) calculate sum(dy * norm_x) on NHW dimensions
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
......@@ -104,10 +169,50 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
}
};
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType multiplier =
type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal);
arg.p_y_[offset] = type_convert<YDataType>(norm_x);
arg.p_dx_[offset] = type_convert<XDataType>(dx);
};
}
};
......@@ -153,33 +258,43 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const void* estimatedMean,
const void* estimatedVariance,
void* p_y) override
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
dyStrides,
dxStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(bnScale),
static_cast<const BiasDataType*>(bnBias),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
static_cast<const MeanVarDataType*>(estimatedMean),
static_cast<const MeanVarDataType*>(estimatedVariance),
static_cast<YDataType*>(p_y));
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......@@ -192,7 +307,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl;
// clang-format on
return str.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment