Commit 463e2aa1 authored by aska-0096

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 6e106c19 236bd148
@@ -15,7 +15,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     constexpr int Rank         = 4;
     constexpr int NumReduceDim = 3;
+    // when using lengths[] to create a tensor, lengths[0] is the length of the highest
+    // dimension, e.g. N of NHWC, so lengths[3] is the length of dimension C of NHWC
     const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
     // input data of the batchnorm forward algorithm
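As a side note, a minimal sketch (sizes made up, not from this test) of how lengths[] relates to an NHWC problem and to the per-channel scale/bias/mean/variance tensors:

// hypothetical NHWC shape, for illustration only
const std::vector<size_t> inOutLengths = {16, 7, 7, 256}; // N, H, W, C
// the per-channel tensors have length C, i.e. lengths[3]
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]}; // {256}
// reducing over N, H, W corresponds to the physical dimension indices {0, 1, 2}
const std::array<int, 3> reduceDims = {0, 1, 2};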
@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
            i_inOutLengths,
            i_inOutStrides,
            i_inOutStrides,
-           {0, 1, 2},
+           {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
            i_scaleBiasMeanVarLengths,
            i_scaleBiasMeanVarStrides,
            i_scaleBiasMeanVarStrides,
@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     {
         using ReferenceBatchNormFwdInstance =
-            ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                                     InOutDataType,
-                                                                                     AccDataType,
-                                                                                     AccDataType,
-                                                                                     AccDataType,
-                                                                                     AccDataType,
-                                                                                     PassThroughOp>;
+            ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
+                                                              InOutDataType,
+                                                              AccDataType,
+                                                              AccDataType,
+                                                              AccDataType,
+                                                              AccDataType,
+                                                              PassThroughOp,
+                                                              Rank,
+                                                              NumReduceDim>;
         auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
            i_inOutLengths,
            i_inOutStrides,
            i_inOutStrides,
-           {0, 1, 2},
+           {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
            i_scaleBiasMeanVarLengths,
            i_scaleBiasMeanVarStrides,
            i_scaleBiasMeanVarStrides,
......
@@ -163,6 +163,13 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
+// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
+#ifdef __gfx908__
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
+#else // __gfx90a__, ...
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
+#endif // __gfx908__
 namespace ck {
 enum struct InMemoryDataOperationEnum
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
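As a hedged illustration (not part of this commit) of how a concrete implementation of this interface might be driven from host code, assuming a hypothetical instance type DeviceBatchNormBwdInstance for an NHWC problem (Rank = 4, NumBatchNormReduceDim = 3):

// hypothetical packed NHWC problem: N=16, H=7, W=7, C=256; all names below are illustrative
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

std::array<ck::index_t, 4> xyLengths  = {16, 7, 7, 256};
std::array<ck::index_t, 4> strides    = {7 * 7 * 256, 7 * 256, 256, 1};
std::array<int, 3>         reduceDims = {0, 1, 2}; // reduce over N, H, W
std::array<ck::index_t, 1> cLengths   = {256};     // per-channel (invariant) length
std::array<ck::index_t, 1> cStrides   = {1};

auto bnorm_bwd = DeviceBatchNormBwdInstance{};
// argument order follows the pure virtual MakeArgumentPointer declared above;
// p_x, p_dy, p_scale, p_savedMean, p_savedInvVar, p_dx, p_dscale, p_dbias are
// device buffers assumed to be allocated elsewhere
auto arg_ptr = bnorm_bwd.MakeArgumentPointer(xyLengths, strides, strides, strides,
                                             reduceDims, cLengths, cStrides, cStrides, cStrides,
                                             p_x, p_dy, p_scale, p_savedMean, p_savedInvVar,
                                             1e-5, PassThrough{}, p_dx, p_dscale, p_dbias);
auto invoker_ptr = bnorm_bwd.MakeInvokerPointer();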
@@ -13,7 +13,15 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
-template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
-using DeviceBatchNormFwdPtr =
-    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
+                                                                 YDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 BiasDataType,
+                                                                 MeanVarDataType,
+                                                                 YElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>>;
 } // namespace device
 } // namespace tensor_operation
......
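For clarity, a hedged sketch of instantiating the widened alias; the data types chosen here (fp16 in/out with fp32 accumulation and statistics) are illustrative, not taken from this commit:

// illustrative NHWC batchnorm-forward handle type: Rank = 4, reduce over 3 dims (N, H, W)
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BnormFwdPtr =
    ck::tensor_operation::device::DeviceBatchNormFwdPtr<ck::half_t,   // XDataType
                                                        ck::half_t,   // YDataType
                                                        float,        // AccDataType
                                                        ck::half_t,   // ScaleDataType
                                                        ck::half_t,   // BiasDataType
                                                        float,        // MeanVarDataType
                                                        PassThrough,  // YElementwiseOp
                                                        4,            // Rank
                                                        3>;           // NumBatchNormReduceDim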
@@ -13,13 +13,22 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-template <index_t Rank, index_t NumBatchNormReduceDim>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormInfer : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const std::array<index_t, Rank> xyLengths,
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
+        const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
         const void* bnScale,
         const void* bnBias,
         double epsilon,
+        const YElementwiseOp y_elementwise_op,
         const void* estimatedMean,
         const void* estimatedInvVariance,
         void* p_y) = 0;
@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
-template <index_t Rank, index_t NumBatchNormReduceDim>
-using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
+                                                                     YDataType,
+                                                                     AccDataType,
+                                                                     ScaleDataType,
+                                                                     BiasDataType,
+                                                                     MeanVarDataType,
+                                                                     YElementwiseOp,
+                                                                     Rank,
+                                                                     NumBatchNormReduceDim>>;
 } // namespace device
 } // namespace tensor_operation
......
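The two new arguments (reduceDims and y_elementwise_op) slot into the inference call as sketched below; names are hypothetical and the arguments elided by this hunk are left elided:

// "bnorm_infer" stands for some concrete DeviceBatchNormInfer instance; illustrative only
std::array<int, 3> reduceDims = {0, 1, 2}; // new: physical indices of the reduced dims (N, H, W)
auto arg_ptr = bnorm_infer.MakeArgumentPointer(xyLengths,
                                               xStrides,
                                               yStrides,
                                               reduceDims,    // new argument
                                               bnScaleBiasMeanVarLengths,
                                               bnScaleStrides,
                                               bnBiasStrides,
                                               /* ...arguments unchanged by this hunk... */
                                               bnScale,
                                               bnBias,
                                               1e-5,          // epsilon
                                               PassThrough{}, // new: y_elementwise_op
                                               estimatedMean,
                                               estimatedInvVariance,
                                               p_y);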
@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-           << KPerBlock
+           << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << getGemmSpecializationString(GemmSpec)
......
@@ -42,8 +42,15 @@ template <typename XDataType,
           index_t ScaleSrcVectorSize,
           index_t BiasSrcVectorSize,
           index_t MeanVarSrcDstVectorSize>
-struct DeviceBatchNormFwdImpl
-    : public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
+struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
+                                                          YDataType,
+                                                          AccDataType,
+                                                          ScaleDataType,
+                                                          BiasDataType,
+                                                          MeanVarDataType,
+                                                          YElementwiseOp,
+                                                          Rank,
+                                                          NumBatchNormReduceDim>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......
@@ -194,21 +194,36 @@ struct Relu
     }
 };
-// https://paperswithcode.com/method/gelu
-// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+// Y = FastGelu(X)
 struct FastGelu
 {
-    template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const;
-
-    template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    // Fast GeLU
+    // https://paperswithcode.com/method/gelu
+    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+    __host__ __device__ static constexpr float GetFastGeLU(float x)
     {
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
-        y = x * cdf;
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+        return x * cdf;
+    }
+
+    template <typename T>
+    static inline constexpr bool is_valid_param_type_v =
+        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
+        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+        || std::is_same_v<T, ck::int4_t>
+#endif
+        ;
+
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
+    {
+        static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
+
+        const float tmp_y = GetFastGeLU(type_convert<float>(x));
+        y = type_convert<Y>(tmp_y);
     }
 };
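As a hedged aside on the rewrite above: the fast form is algebraically the same as the quoted tanh formula, since tanh(u/2) = 2/(1+exp(-u)) - 1 and 0.797885 * 0.044715 ≈ 0.035677. A small standalone comparison (illustrative only, not part of the library):

#include <cmath>
#include <cstdio>

// reference formulation: y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
float gelu_tanh(float x)
{
    const float k = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

// same shape as FastGelu::GetFastGeLU above: tanh(u/2) rewritten through exp(-u)
float gelu_fast(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

int main()
{
    for(float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f})
        std::printf("x=% .2f  tanh form=% .6f  fast form=% .6f\n", x, gelu_tanh(x), gelu_fast(x));
    return 0;
}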
......
@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    // clang-format off
+    // First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
+    // clang-format on
     __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                                const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
                                const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
......
@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
        auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
+       // calculate inv-variance as 1/sqrt(epsilon+variance)
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_var_thread_buf(I) =
                type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
......
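The saved statistic here is the inverse standard deviation, 1/sqrt(epsilon + variance), rather than the variance itself; a brief hedged sketch (names illustrative) of why that is convenient for the per-element normalization that follows:

// with invVar = 1/sqrt(epsilon + variance) precomputed once per channel, each
// normalized element needs only multiplies and adds, no divide or sqrt:
//     y = scale * (x - mean) * invVar + bias
inline float normalize_one(float x, float mean, float invVar, float scale, float bias)
{
    return scale * (x - mean) * invVar + bias;
}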
@@ -874,6 +874,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        } // end gemm1
+       // workaround compiler issue; see ck/ck.hpp
+       if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
+                    is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
+                    Gemm1NPerBlock == 128)
+       {
+           __builtin_amdgcn_sched_barrier(0);
+       }
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
......
@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
        auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
+       // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            var_thread_buf(I) =
                type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
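For context, a hedged scalar sketch of the Welford recurrence that the ThreadwiseWelford/BlockwiseWelford pieces above parallelize (this is the textbook single-pass update, not the library's exact code):

#include <cstdio>

// One-pass Welford update: after n samples, `mean` is exact and var_sum/n is the
// (biased) variance. The grid kernel above produces such partial (mean, variance,
// count) triples per block group, which a second-half kernel then merges.
struct WelfordState
{
    float mean    = 0.f;
    float var_sum = 0.f; // sum of squared deviations from the running mean (a.k.a. M2)
    int   count   = 0;

    void update(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        var_sum += delta * (x - mean);
    }
};

int main()
{
    WelfordState w;
    for(float x : {1.f, 2.f, 4.f, 8.f})
        w.update(x);
    std::printf("mean=%f variance=%f count=%d\n", w.mean, w.var_sum / w.count, w.count);
    return 0;
}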
@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
-       reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
+       reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
 };
......