Commit 29448ffd authored by Harisankar Sadasivan's avatar Harisankar Sadasivan

merge from develop and revision for pr#881

parents 9223a5e2 8f84a012
@@ -38,16 +38,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType,
InElementwiseOp,
AccElementwiseOp,
- Rank>
+ Rank,
+ NumReduceDim>
{
- static constexpr index_t kRank = Rank;
- static constexpr index_t kNumReduceDim = NumReduceDim;
- static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
- virtual index_t GetRank() const override { return kRank; }
- virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumSrcDim = Rank;
@@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{
if constexpr(InSrcVectorDim == 0)
{
- if constexpr(kNumInvariantDim == 0)
+ if constexpr(NumInvariantDim == 0)
{
return false;
}
else
{
- if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
+ if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
{
return false;
}
@@ -316,7 +309,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
}
// To improve
- if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
+ if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
{
return false;
}
...
@@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
- using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t<decltype(
-     GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
- using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t<decltype(
-     GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
+ using AGridDesc_AKB_AK0_M_AK1 =
+     remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(
+         AGridDesc_M_K{}, 1))>;
+ using BGridDesc_BKB_BK0_N_BK1 =
+     remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(
+         BGridDesc_N_K{}, 1))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
@@ -886,11 +888,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
has_main_loop>;
- hipGetErrorString(hipMemset(
+ hipGetErrorString(hipMemsetAsync(
arg.p_e_grid_,
0,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
- sizeof(EDataType)));
+ sizeof(EDataType),
+ stream_config.stream_id_));
return launch_and_time_kernel(stream_config,
kernel,
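The memset change above queues the zero-fill of the split-K accumulation buffer on the stream carried in stream_config instead of issuing a blocking hipMemset, so the clear and the subsequent kernel launch are ordered on the same stream without a device-wide synchronization. A minimal standalone HIP sketch of that pattern (buffer size, element type and stream handling are illustrative, not taken from the diff):

#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    constexpr size_t num_elems = 1024; // hypothetical size of the E accumulation buffer
    float* p_e_grid            = nullptr;
    hipStream_t stream;

    (void)hipStreamCreate(&stream);
    (void)hipMalloc(&p_e_grid, num_elems * sizeof(float));

    // Same idea as the diff: clear the buffer on the compute stream, not with a blocking call.
    (void)hipMemsetAsync(p_e_grid, 0, num_elems * sizeof(float), stream);

    // ... kernels that atomically accumulate partial K results into p_e_grid would be
    // launched on `stream` here, and are guaranteed to see the zeroed buffer ...

    (void)hipStreamSynchronize(stream);
    (void)hipFree(p_e_grid);
    (void)hipStreamDestroy(stream);
    printf("buffer cleared and released\n");
    return 0;
}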
@@ -939,9 +942,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
- if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-      ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-      ck::get_device_name() == "gfx942"))
+ if(!ck::is_xdl_supported())
{
return false;
}
...
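The architecture check is now delegated to ck::is_xdl_supported() rather than an inlined list of gfx names. As a hedged sketch only (the real helper lives elsewhere in CK and may be implemented differently), an equivalent predicate can be written in terms of the same device-name comparison the removed code performed:

#include <string>
#include <cstdio>

// Stand-in for ck::get_device_name(); the real function queries the HIP runtime.
static std::string get_device_name() { return "gfx90a"; }

// Plausible shape of an is_xdl_supported()-style helper: XDL/MFMA code paths exist on the
// CDNA-class targets the removed check listed explicitly.
static bool is_xdl_supported()
{
    const std::string name = get_device_name();
    return name == "gfx908" || name == "gfx90a" || name == "gfx940" || name == "gfx941" ||
           name == "gfx942";
}

int main()
{
    printf("xdl supported: %d\n", is_xdl_supported() ? 1 : 0);
    return 0;
}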
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
@@ -36,6 +36,13 @@ struct Add
y = x0 + type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 + x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -179,6 +186,13 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
{
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
};
float alpha_;
float beta_;
};
...
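The new Add and Bilinear specializations above follow CK's usual elementwise-functor pattern: the primary operator() template is only declared, and each supported type combination is an explicit specialization, so an unsupported mix of input/output types fails at compile or link time. A small standalone illustration of that pattern (plain float stands in for CK's half_t narrowing; this is not CK code):

#include <cstdio>

struct Add
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const; // declared, never defined
};

// Supported combination, analogous to the new float + float specialization in the diff
// (CK would finish with y = type_convert<half_t>(x0 + x1)).
template <>
void Add::operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{
    y = x0 + x1;
}

int main()
{
    Add add;
    float y = 0.f;
    add(y, 1.5f, 2.25f); // template arguments deduced; calls the specialization
    printf("%f\n", y);   // 3.750000
    // add(y, 1, 2);     // int inputs: no specialization exists, so this would fail to link
    return 0;
}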
@@ -195,6 +195,51 @@ struct AddMultiply
}
};
// C = A * B
// E = C x D0 + D1
struct MultiplyAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c * d0) + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = type_convert<half_t>(c) * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = c * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
const float& c,
const float& d0,
const float& d1) const
{
const float y = c * d0 + d1;
e = y;
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
...
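MultiplyAdd above computes e = c * d0 + d1, with the specializations differing only in where the half_t/float conversion happens. A tiny numeric check of the intended arithmetic, using plain float values chosen for illustration:

#include <cstdio>

int main()
{
    const float c = 0.5f, d0 = 4.0f, d1 = 1.0f;
    const float e = c * d0 + d1; // 0.5 * 4.0 + 1.0 = 3.0
    printf("%f\n", e);           // every MultiplyAdd specialization would produce 3 here
    return 0;
}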
@@ -6,6 +6,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace tensor_operation {
@@ -38,6 +39,12 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
@@ -81,6 +88,36 @@ struct PassThrough
y = x;
}
#endif
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, f8_t>(float& y, const f8_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<f8_t, float>(f8_t& y, const float& x) const
{
y = type_convert<f8_t>(x);
}
template <>
__host__ __device__ void operator()<half_t, f8_t>(half_t& y, const f8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<f8_t, half_t>(f8_t& y, const half_t& x) const
{
y = type_convert<f8_t>(x);
}
};
struct UnaryConvert
@@ -109,6 +146,23 @@ struct ConvertBF16RTN
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
...
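ConvertF8SR converts float or half_t inputs to f8_t with stochastic rounding through f8_convert_sr. The sketch below illustrates the stochastic-rounding idea on an ordinary quantization grid; it does not reproduce f8_t's bit-level format, and the step size is made up. A value is rounded up with probability equal to its fractional distance from the lower representable value, so the result is unbiased in expectation:

#include <cmath>
#include <cstdio>
#include <random>

static float round_stochastic(float x, float step, std::mt19937& rng)
{
    std::uniform_real_distribution<float> uni(0.0f, 1.0f);
    const float lower = std::floor(x / step) * step; // nearest representable value below x
    const float frac  = (x - lower) / step;          // in [0, 1)
    return (uni(rng) < frac) ? lower + step : lower;
}

int main()
{
    std::mt19937 rng(42);
    const float x = 0.30f, step = 0.25f; // grid: ..., 0.25, 0.50, ...
    double sum = 0.0;
    for(int i = 0; i < 100000; ++i)
        sum += round_stochastic(x, step, rng);
    printf("mean rounded value ~= %f (expect ~0.30)\n", sum / 100000.0);
    return 0;
}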
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace ck {
template <typename GridwiseMultiblockBatchNormForward_,
typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_batchnorm_forward(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
y_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
mean_var_count_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
epsilon,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count,
p_control,
p_scale,
p_bias,
y_elementwise_op,
p_y,
updateMovingAverage,
averageFactor,
resultRunningMean,
resultRunningVariance,
saveMeanInvVariance,
resultSaveMean,
resultSaveInvVariance);
};
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseMultiblockBatchNormForward
{
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadwiseWelford1 =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using ThreadwiseWelford2 =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford1 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using BlockwiseWelford2 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
using ck::math::sqrt;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
if(block_local_id == 0)
gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> tmp_count_thread_buf;
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
// Step 1: each workgroup does local welford reduction
auto threadwise_welford_1 = ThreadwiseWelford1();
threadwise_welford_1.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
count_thread_buf(I) = threadwise_welford_1.cur_count_;
BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto mean_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto var_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto count_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_mean_var_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_count_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
if(thread_k_cluster_id == 0)
{
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_m_g,
mean_global_val_buf);
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
var_thread_buf,
mean_var_count_grid_desc_m_g,
var_global_val_buf);
threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
count_thread_buf,
mean_var_count_grid_desc_m_g,
count_global_val_buf);
};
gms_barrier(&p_control[blkgroup_id * 2]);
if(block_local_id == 0)
gms_reset(&p_control[blkgroup_id * 2]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
count_thread_buf(I) = 0;
});
constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
int32_t reducedSize = 0;
while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_count_thread_buf);
ThreadwiseWelford2::Run(tmp_mean_thread_buf,
tmp_var_thread_buf,
tmp_count_thread_buf,
mean_thread_buf,
var_thread_buf,
count_thread_buf);
reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 4: do normalization using the mean/variance
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
YElementwiseOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
y_elementwise_op);
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasSrcVectorSize,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias, bias_grid_desc_m.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y, y_grid_desc_m_k.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_bias_load.Run(bias_grid_desc_m,
bias_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
bias_thread_buf);
threadwise_x_load.SetSrcSliceOrigin(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
AccDataType fused_mean_bias =
bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
// normalize
y_thread_buf(Number<offset>{}) =
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
}
// Step 5: update the moving average of mean and variance (optional)
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_var_thread_buf;
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf);
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
mean_thread_buf[I] * averageFactor;
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
var_thread_buf[I] * averageFactor;
});
auto threadwise_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf,
mean_var_grid_desc_m,
running_mean_global_buf);
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf,
mean_var_grid_desc_m,
running_var_global_buf);
};
// Step 6: save mean and inv-variance (optional)
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
{
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
});
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
result_mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
var_thread_buf,
mean_var_grid_desc_m,
result_inv_var_global_buf);
};
}
};
} // namespace ck
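Steps 2 and 3 of the kernel above write per-workgroup partial mean/variance/count triples to workspace memory and then merge them. The ThreadwiseWelfordMerge/BlockwiseWelford helpers are assumed to implement the standard parallel (Chan-style) combination of two Welford partials; a scalar reference sketch of that merge, independent of CK's buffers and types:

#include <cstdio>

struct WelfordPartial
{
    double mean;
    double var; // population variance of the partial set
    int count;
};

static WelfordPartial merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out{};
    out.count          = a.count + b.count;
    const double delta = b.mean - a.mean;
    out.mean           = a.mean + delta * b.count / out.count;
    out.var            = (a.var * a.count + b.var * b.count +
                          delta * delta * a.count * b.count / out.count) /
                         out.count;
    return out;
}

int main()
{
    // Partials for {1,2} and {3,4,5}; merging must give the stats of {1,2,3,4,5}.
    const WelfordPartial a{1.5, 0.25, 2};
    const WelfordPartial b{4.0, 2.0 / 3.0, 3};
    const WelfordPartial m = merge(a, b);
    printf("mean=%f var=%f count=%d\n", m.mean, m.var, m.count); // mean=3, var=2, count=5
    return 0;
}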
@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
- using ThreadReduceSrcDesc_M_1 = decltype(
-     make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+ using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+     make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
- 1,
+ 0,
1,
InMemoryDataOperationEnum::Set,
1,
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
- 1,
+ 0,
1,
InMemoryDataOperationEnum::Set,
1,
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
- index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m,
blkgroup_size,
num_xy_k_block_tile_iteration,
- num_mean_var_count_k_block_tile_iteration,
epsilon,
p_in_welford_mean,
p_in_welford_variance,
@@ -123,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
- using ThreadReduceSrcDesc_M_1 = decltype(
-     make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+ using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+     make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
- index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
- 1,
+ 0,
1,
1,
true>(
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
- 1,
+ 0,
1,
1,
true>(
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
- constexpr auto mean_var_count_thread_copy_step_m_k =
-     make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf(I) = 0;
});
- for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
-     ++reducedTiles)
+ constexpr auto mean_var_count_thread_copy_step_m_k =
+     make_multi_index(0, KThreadClusterSize);
+ int32_t reducedSize = 0;
+ while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf,
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf,
welford_count_thread_buf);
+ reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
...
@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
- using ThreadReduceSrcDesc_M_1 = decltype(
-     make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+ using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
+     make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...
@@ -7,6 +7,8 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
#include <stdlib.h>
namespace ck {
@@ -585,7 +587,8 @@ struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
- OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+ __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                              index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
@@ -679,4 +682,406 @@ struct BlockToCTileMap_3DGrid_KSplit
}
};
enum StreamKReductionStrategy
{
Atomic = 0, // sk blocks use atomics to do the reduction
Reduction,  // some workgroups are made responsible for doing the reduction operation
};
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t dp_start_block_idx;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv eqav_tiles_big; // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
uint32_t k,
uint32_t num_cu,
uint32_t occupancy,
uint32_t sk_blocks = 0xffffffff)
{
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
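// Illustrative numbers (not from the source): with num_cu = 120 and num_tiles = 300,
// full_dispatches = 300 / 120 = 2, full_dispatch_tiles = 240, and the remaining
// partial_dispatche_tiles = 60 form the partial dispatch that the stream-k blocks absorb.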
uint32_t full_dispatches = num_tiles / num_cu;
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
uint32_t sk_occupancy = occupancy;
uint32_t dp_tiles = full_dispatch_tiles;
uint32_t sk_tiles = partial_dispatche_tiles;
if(full_dispatches < occupancy)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy = 1; // TODO: single occ seems better
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy = 1; // left 1 slot for sk occupancy
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
dp_tiles = full_dispatch_tiles - num_cu;
sk_tiles = partial_dispatche_tiles + num_cu;
}
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
uint32_t max_sk_tiles =
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
: math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
// if use dp for sk-block, how many iters do we need
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
uint32_t tentative_sk_iters_per_block =
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
if(tentative_sk_blocks % sk_tiles != 0)
{
// penalty for uneven divide
cross_sk_blocks_overhead +=
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
}
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
if(tentative_sk_score < best_sk_score)
{
best_sk_score = tentative_sk_score;
sk_num_blocks = tentative_sk_blocks;
}
}
if(best_sk_score >= dp_for_sk_iters)
{
sk_num_blocks = 0;
}
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
if(sk_num_blocks == 0)
{
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
else
{
// k_iters_per_sk_block is the floor of the average number of iters each sk block covers.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
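// Worked example (illustrative only): sk_total_iters = 10, sk_num_blocks = 4
// => m = 10 / 4 = 2 and b = 10 - 2 * 4 = 2, so 2 big blocks run 3 iters each and
// 2 little blocks run 2 iters each (2 * 3 + 2 * 2 = 10).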
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
#if 0
printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
"sk_tiles:%u, workspace(acc float):%u\n",
num_cu,
occupancy,
get_grid_dims().x,
num_tiles,
dp_tiles,
sk_num_big_blocks,
sk_num_blocks,
sk_total_iters,
dp_start_block_idx,
dp_iters_per_block,
dp_num_blocks,
k_iters_per_tile.get(),
k_iters_per_big_block,
reduction_start_block_idx,
get_sk_tiles(),
get_workspace_size(sizeof(float)));
#endif
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ dim3 get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
}
else
return dim3(reduction_start_block_idx, 1, 1);
}
__device__ uint32_t get_block_idx() const
{
// TODO: swizzle block index for better locality
return __builtin_amdgcn_readfirstlane(blockIdx.x);
}
__device__ void
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
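// Illustrative mapping for the function above (numbers not from the source): with
// sk_num_big_blocks = 2, sk_num_blocks = 3 and k_iters_per_big_block = 3, block 0 covers
// iters [0,3), block 1 covers [3,6) and the little block 2 covers [6,8); dp blocks then
// start at sk_total_iters = 8 and each take k_iters_per_tile iters.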
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
uint32_t iter_end,
uint32_t total_iter_length) const
{
uint32_t iter_length_mod, iter_length_quo /*unused*/;
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
uint32_t current_iter_length = math::min(
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
return current_iter_length;
}
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
__device__ void
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
{
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// swizzle tile
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
? tile_swizzle_sub_m
: tile_swizzle_sub_m_rem;
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
uint32_t quo_, rem_;
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_eqav_tiles_ + rem_;
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck
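The new BlockToCTileMap_GemmStreamK is intended to be constructed on the host and passed to the GEMM kernel by value. A hedged host-side usage sketch (tile sizes and the function name are illustrative, the include path is omitted because it depends on the CK layout, and num_cu/occupancy would normally come from the HIP device properties):

#include <cstdint>
// #include "<ck path>/block_to_ctile_map.hpp" // path depends on the CK install layout

void plan_streamk_gemm(uint32_t M, uint32_t N, uint32_t K, uint32_t num_cu, uint32_t occupancy)
{
    using TileMap = ck::BlockToCTileMap_GemmStreamK<256, 128, 32,
                                                    ck::StreamKReductionStrategy::Reduction>;

    // sk_blocks is left at its 0xffffffff default, letting the constructor pick the split.
    TileMap tile_map(M, N, K, num_cu, occupancy);

    const auto grid         = tile_map.get_grid_dims();                   // sk + dp (+ reduction) blocks
    const uint32_t ws_bytes = tile_map.get_workspace_size(sizeof(float)); // partial-acc tiles + semaphores

    // A real launcher would hipMalloc ws_bytes of workspace, launch with `grid`, and each
    // block would call get_block_idx()/get_block_itr()/get_tile_idx_with_offset() to find
    // its share of the K loop.
    (void)grid;
    (void)ws_bytes;
}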