Commit 1abaedd9 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into gpu-invoker

parents bd2b3dd7 cb3fac4d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle
};
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
default: return "Unrecognized specialization!";
}
}
struct MaskDisabledPredicate
{
__host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
{
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
{
return false;
}
};
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
{
return operator()(m + m_tile - 1, n);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
{
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// index_t MRaw_;
index_t NRaw_;
MaskOutPredicate predicate_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForBlockwiseWelford
{
GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration,
long_index_t reduce_length)
: numBlockTileIteration_{numBlockTileIteration}
{
count_in_last_tile_ = reduce_length % K_BlockTileSize;
};
__device__ index_t operator()(index_t thread_k_cluster_id) const
{
if(count_in_last_tile_ == 0)
return (KThreadSliceSize * numBlockTileIteration_);
else
{
index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize;
index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIteration_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
else
return (KThreadSliceSize * (numBlockTileIteration_ - 1));
};
};
index_t numBlockTileIteration_;
index_t count_in_last_tile_;
};
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForMultiblockWelford
{
GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize,
index_t numBlockTileIteration,
long_index_t reduce_length)
: blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
{
last_block_reduce_length_ =
reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
numBlockTileIterationByLastBlock_ =
(last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
__device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
{
if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ ||
block_local_id < blkGroupSize_ - 1)
return (KThreadSliceSize * numBlockTileIteration_);
index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize;
if(count_in_last_tile == 0)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else
{
index_t num_complete_slice = count_in_last_tile / KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) +
count_in_last_tile);
else
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1));
};
};
index_t blkGroupSize_;
index_t numBlockTileIteration_;
index_t last_block_reduce_length_;
index_t numBlockTileIterationByLastBlock_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -49,6 +49,14 @@ struct Add ...@@ -49,6 +49,14 @@ struct Add
y = x0 + x1; y = x0 + x1;
}; };
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
...@@ -67,6 +75,30 @@ struct Add ...@@ -67,6 +75,30 @@ struct Add
}; };
}; };
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
template <>
__host__ __device__ void
operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
float scale_;
};
struct Subtract struct Subtract
{ {
template <typename T> template <typename T>
...@@ -118,6 +150,13 @@ struct Bilinear ...@@ -118,6 +150,13 @@ struct Bilinear
template <typename Y, typename X0, typename X1> template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <>
__host__ __device__ constexpr void
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
{
y = alpha_ * x0 + beta_ * x1;
};
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<float, float, float>(float& y, const float& x0, const float& x1) const operator()<float, float, float>(float& y, const float& x0, const float& x1) const
...@@ -187,6 +226,22 @@ struct AddRelu ...@@ -187,6 +226,22 @@ struct AddRelu
const float a = x0 + type_convert<float>(x1); const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f; y = a > 0.0f ? a : 0.0f;
}; };
template <>
__host__ __device__ constexpr void
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
}; };
struct AddHardswish struct AddHardswish
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -171,6 +172,42 @@ struct AddAdd ...@@ -171,6 +172,42 @@ struct AddAdd
} }
}; };
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
// C = A * B // C = A * B
// E = FastGelu(C + D0 + D1) // E = FastGelu(C + D0 + D1)
struct AddAddFastGelu struct AddAddFastGelu
...@@ -277,6 +314,40 @@ struct Normalize ...@@ -277,6 +314,40 @@ struct Normalize
double epsilon_; double epsilon_;
}; };
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T3& gamma,
const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
template <typename Y, typename X> template <typename Y, typename X>
struct UnaryTypeConvert; struct UnaryTypeConvert;
......
#pragma once
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul_Clamp
{
Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
{
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul2_Clamp
{
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale1_;
float requantScale2_;
Activation activationOp_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
namespace ck { namespace ck {
...@@ -94,6 +95,12 @@ struct Scale ...@@ -94,6 +95,12 @@ struct Scale
y = scale_ * x; y = scale_ * x;
}; };
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
float scale_; float scale_;
}; };
...@@ -193,21 +200,36 @@ struct Relu ...@@ -193,21 +200,36 @@ struct Relu
} }
}; };
// https://paperswithcode.com/method/gelu // Y = FastGelu(X)
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
struct FastGelu struct FastGelu
{ {
template <typename Y, typename X> // Fast GeLU
__host__ __device__ void operator()(Y& y, const X& x) const; // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
template <> __host__ __device__ static constexpr float GetFastGeLU(float x)
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u); const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
y = x * cdf; const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
} }
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename XDataType,
typename DyDataType,
typename DxDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
dscale_dbias_grid_desc_m_k,
mean_var_grid_desc_m,
scale_grid_desc_m,
bias_grid_desc_m,
blkgroup_size,
reduce_size,
num_xy_k_block_tile_iteration,
num_dscale_dbias_k_block_tile_iteration,
p_reduce_dscale,
p_reduce_dbias,
p_mean,
p_inv_var,
p_x,
p_dy,
p_scale,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
};
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyDxVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0 &&
MThreadSliceSize % DxDstVectorSize == 0) ||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0 &&
KThreadSliceSize % DxDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_1,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dx_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
inv_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// clang-format off
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dscale_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
// clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
// clang-format on
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DySrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DxDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DxDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dx_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_mean_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
AccDataType inv_reduce_size =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
dbias_thread_buf[iM] - tmpVal);
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
}
};
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseWelfordSecondHalfBatchNormForwardFinal_,
typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M>
__global__ void kernel_welford_second_half_batchnorm_forward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K y_grid_desc_m_k,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const XDataType* const __restrict__ p_x,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k,
y_grid_desc_m_k,
mean_var_count_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
p_x,
p_scale,
p_bias,
y_elementwise_op,
p_y,
updateMovingAverage,
averageFactor,
resultRunningMean,
resultRunningVariance,
saveMeanInvVariance,
resultSaveMean,
resultSaveInvVariance);
};
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfBatchNormForwardFinal
{
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& y_grid_desc_m_k,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const XDataType* const __restrict__ p_x,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
using ck::math::sqrt;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
// Step 2: do normalization and output y
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
YElementwiseOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
y_elementwise_op);
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasSrcVectorSize,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias, bias_grid_desc_m.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y, y_grid_desc_m_k.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_bias_load.Run(bias_grid_desc_m,
bias_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
bias_thread_buf);
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon);
AccDataType fused_mean_bias =
bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
y_thread_buf(Number<offset>{}) =
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k);
}
// Step 3: update the moving average of mean and variance (optional)
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_var_thread_buf;
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_var_load_m =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
running_mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf);
threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
running_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf);
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
welford_mean_thread_buf[I] * averageFactor;
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
welford_var_thread_buf[I] * averageFactor;
});
auto threadwise_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf,
mean_var_grid_desc_m,
running_mean_global_buf);
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf,
mean_var_grid_desc_m,
running_var_global_buf);
};
// Step 4: save mean and inv-variance (optional)
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
{
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
});
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
welford_mean_thread_buf,
mean_var_grid_desc_m,
result_mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
welford_var_thread_buf,
mean_var_grid_desc_m,
result_inv_var_global_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
dscale_dbias_grid_desc_m_g,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
dy_elementwise_op,
p_out_welford_mean,
p_out_welford_inv_variance,
p_x,
p_dy,
p_reduce_dscale,
p_reduce_dbias);
};
template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0) ||
(XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance)
// Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& mean_thread_buf =
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
// buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// clang-format off
// Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<AccDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(welford_mean_thread_buf(I),
welford_var_thread_buf(I),
welford_count_thread_buf(I));
});
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
});
if(block_local_id == 0 && thread_k_cluster_id == 0)
{
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf,
mean_var_grid_desc_m,
inv_var_global_buf);
};
};
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
DySrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
// clang-format off
// Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
});
});
ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0)
{
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
};
};
};
} // namespace ck
...@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
index_t idx_M01 = idx_M0 % M01_; index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt); idx_N0_M01_local / M01_adapt);
} }
...@@ -364,14 +408,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -364,14 +408,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
index_t M01 = 1, index_t M01 = 1,
index_t N01 = 1, index_t N01 = 1,
index_t KSplit = 1) index_t KSplit = 1)
: M01_(M01), : c_grid_desc_m_n_(c_grid_desc_m_n),
M01_(M01),
N01_(N01), N01_(N01),
KSplit_(KSplit), KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -387,7 +433,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -387,7 +433,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{ {
return underlying_map_.CalculateBottomIndex(idx_top); static_assert(TopIdx::Size() == 1);
return underlying_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] % CalculateGridSize()));
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -418,6 +467,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -418,6 +467,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
} }
private: private:
__device__ constexpr index_t CalculateGridSize() const
{
return CalculateGridSize(c_grid_desc_m_n_);
}
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01, index_t M01,
index_t N01, index_t N01,
...@@ -450,6 +504,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -450,6 +504,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
} }
CGridDesc_M_N c_grid_desc_m_n_;
index_t M01_, N01_, KSplit_; index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_; UnderlyingMap underlying_map_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : F[M, N0], where N0 is number of blocks along N dimension
// output : G[M, N0], where N0 is number of blocks along N dimension
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// F, G = welford(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension for E
template <typename ABDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType),
c_block_size * sizeof(CShuffleDataType));
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_pass_through_transform(NBlock)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
return grid_desc_mblock_mperblock_nblock;
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C, Welford and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>,
false>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_der_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
false>{};
// LDS c_shuffle_block_desc_mperblock_nperblock
constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
PostShuffleThreadClusterSize_M_N::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
PostShuffleThreadClusterSize_M_N::At(I1) ==
0,
"wrong!");
constexpr index_t PostShuffleThreadSliceSize_M =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I0);
constexpr index_t PostShuffleThreadSliceSize_N =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
PostShuffleThreadClusterSize_M_N::At(I1);
constexpr auto PostShuffleThreadSliceSize_M_N =
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
// VGPR post_shuffle_thread_desc_m_n
constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
// To apply D0, D1, ... and Welford.
// threadwise copy from LDS to VGPR
constexpr auto post_shuffle_thread_cluster_desc =
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
const auto post_shuffle_thread_cluster_idx =
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto post_shuffle_thread_data_idx_begin =
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
// To apply D0, D1, ... and Welford.
// Copy c shuffle from LDS back to VGPR
auto post_shuffle_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
AccDataType,
decltype(c_shuffle_block_desc_mperblock_nperblock),
decltype(post_shuffle_thread_desc_m_n),
decltype(PostShuffleThreadSliceSize_M_N),
Sequence<0, 1>,
1,
PostShuffleScalarPerVector,
1,
true>{c_shuffle_block_desc_mperblock_nperblock,
post_shuffle_thread_data_idx_begin};
// D0, D1, ..., Dn
constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<PostShuffleThreadSliceSize_M>{},
I1,
Number<PostShuffleThreadSliceSize_N>{}));
// FIXME: Decrease usage of VGPR
// Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
auto ds_thread_buf = generate_tuple(
[&](auto) {
return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
},
Number<NumDTensor>{});
// Copy D0, D1, ..., Dn from global to VGPR
auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
DDataType,
AccDataType,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>,
Sequence<0, 1, 2, 3>,
3,
PostShuffleScalarPerVector,
1,
true>(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
PostShuffleScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
true>{
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
// Welford
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
decltype(thread_welford_src_desc_m_k),
decltype(thread_welford_dst_desc_m)>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>,
false>;
constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
constexpr int num_shuffleN =
NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);
using mean_var_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block
if(block_work_idx[I1] % nblock == nblock - 1)
{
constexpr index_t NPerShuffleBlock =
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
int thread_max_len =
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
int shuffle_step = 0;
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
{
++shuffle_step;
thread_max_len += NPerShuffleBlock;
}
int delta = 0;
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
delta = 0;
else if(NPerBlockTail > thread_max_len)
delta = PostShuffleThreadSliceSize_N;
else
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
}
static_for<0, num_shuffleM, 1>{}([&](auto i) {
threadwise_welfords(i).max_count_ = max_count;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to read from LDS
block_sync_lds();
// each thread shuffle data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to write to LDS
block_sync_lds();
// Get shuffle data from LDS to VGPR
post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
post_shuffle_thread_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
// Global read D0, D1, ...
static_for<0, NumDTensor, 1>{}([&](auto Id) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
d_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
ds_grid_buf[Id],
post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
ds_thread_buf(Id));
if constexpr(access_id < num_access - 1)
{
// move on D0, D1, ...
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
}
});
// cde_element_op(e, c, d0, d1, ...);
static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
const auto c_ds_src_data_refs = concat_tuple_of_reference(
tie(e_thread_buf[i]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
Number<NumDTensor>{}));
auto e_dst_data_refs = tie(e_thread_buf(i));
unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
});
// Global write E
e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
if constexpr(access_id < num_access - 1)
{
// move on E
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
}
// Threadwise welford
auto& threadwise_welford = threadwise_welfords(shuffleM_index);
auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
auto& var_thread_buf = var_thread_bufs(shuffleM_index);
threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
constexpr int shuffleMInc =
de_global_step[I1] /
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
}
}); // copy c, d, e + welford
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
count_thread_buf(j) = threadwise_welfords(i).cur_count_;
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
if(post_shuffle_thread_cluster_idx[I1] == 0)
{
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
auto mean_var_count_thread_copy_index = make_multi_index(
block_work_idx[I0], // mblock
shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]); // nblock
auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EMeanVarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>{mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf); // write mean
mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); // write variance
// Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
// to be written.
if(i == 0 && block_work_idx[I0] == 0 &&
post_shuffle_thread_cluster_idx[I0] == 0)
{
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{count_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf); // write count
}
}
});
} // shuffle C + Ds + welford + write out
} // run
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment