Commit 05ee41c3 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Merge branch 'develop' into lwpck-471

parents 37116c98 ad541ad6
......@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
dscale_dbias_grid_desc_m_g,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
dy_elementwise_op,
p_out_welford_mean,
p_out_welford_inv_variance,
p_x,
p_dy,
p_reduce_dscale,
p_reduce_dbias);
};
template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename DscaleDbiasGridDesc_M_G,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0) ||
(XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance)
// Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& mean_thread_buf =
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
// buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// clang-format off
// Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<AccDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(welford_mean_thread_buf(I),
welford_var_thread_buf(I),
welford_count_thread_buf(I));
});
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
});
if(block_local_id == 0 && thread_k_cluster_id == 0)
{
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf,
mean_var_grid_desc_m,
inv_var_global_buf);
};
};
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyVectorDim,
DySrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
// clang-format off
// Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
});
});
ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0)
{
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
};
};
};
} // namespace ck
......@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
index_t M01 = 1,
index_t N01 = 1,
index_t KSplit = 1)
: M01_(M01),
: c_grid_desc_m_n_(c_grid_desc_m_n),
M01_(M01),
N01_(N01),
KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
......@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
static_assert(TopIdx::Size() == 1);
return underlying_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] % CalculateGridSize()));
}
template <typename CTileIdx, typename CTileDim>
......@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
}
private:
__device__ constexpr index_t CalculateGridSize() const
{
return CalculateGridSize(c_grid_desc_m_n_);
}
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01,
......@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
}
CGridDesc_M_N c_grid_desc_m_n_;
index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_;
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -74,7 +74,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -77,7 +77,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle>
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -108,7 +109,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -503,8 +505,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......@@ -691,12 +694,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
,
true
#endif
>{};
decltype(thread_slice_desc_m_n)>{};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
......@@ -876,6 +874,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
} // end gemm1
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
{
__builtin_amdgcn_sched_barrier(0);
}
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
scale_grid_desc_m,
dscale_dbias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
reduce_size,
num_k_block_tile_iteration,
epsilon,
p_x,
p_dy,
p_scale,
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
};
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XDyDxVectorDim,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
MThreadSliceSize % DySrcVectorSize == 0 &&
MThreadSliceSize % DxDstVectorSize == 0) ||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
KThreadSliceSize % DySrcVectorSize == 0 &&
KThreadSliceSize % DxDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ck::reduce::Add,
false>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ck::reduce::Add,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Blockwise BatchNorm Backward
// Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
// Output: dx, dscale, dbias
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
// Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
using ck::math::sqrt;
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dx_thread_buf;
// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
XSrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DxDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XDyDxVectorDim,
DxDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
inv_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
});
threadwise_x_load.SetSrcSliceOrigin(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
};
// clang-format off
// Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance)
// clang-format on
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0);
dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dx_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
});
});
ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
};
// clang-format off
// Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly
// clang-format on
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
AccDataType inv_reduce_size =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
dbias_thread_buf[iM] - tmpVal);
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
};
} // namespace ck
......@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
template <typename InDataTypePointerTuple,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename XElementwiseOperation,
typename YElementwiseOperation,
typename InGrid2dDescTuple,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool SweepOnce>
struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
{
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
static constexpr auto XThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
static constexpr auto BetaThreadBufferNumber = Number<KThreadSliceSize / BetaSrcVectorSize>{};
static constexpr auto YThreadBufferNumber = Number<KThreadSliceSize / YDstVectorSize>{};
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
{
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
int thread_max_len =
(thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
int delta = thread_max_len - kPerBlockTail;
delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
kPerThread += XSrcVectorSize - delta;
});
}
return kPerThread;
}
__device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const InDataTypePointerTuple p_in_global_tuple,
XDataType* const __restrict__ p_x_lds,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const XElementwiseOperation x_elementwise_op,
const YElementwiseOperation y_elementwise_op)
{
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t grid_size = get_grid_size();
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() ==
2); // matrix dimension
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
auto in_thread_buf_tuple = generate_tuple(
[&](auto) {
return generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<NumInput>{});
},
Number<XThreadBufferNumber>{});
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<XThreadBufferNumber>{});
auto gamma_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>{};
},
Number<GammaThreadBufferNumber>{});
auto beta_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * BetaSrcVectorSize,
true>{};
},
Number<BetaThreadBufferNumber>{});
auto y_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * YDstVectorSize,
true>{};
},
Number<YThreadBufferNumber>{});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
AccDataType,
decltype(in_grid_2d_desc_tuple[I]),
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
false>{
in_grid_2d_desc_tuple[I],
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * XSrcVectorSize)};
},
Number<NumInput>{});
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * XSrcVectorSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * GammaSrcVectorSize));
auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
BetaSrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * BetaSrcVectorSize));
using PassThrough = tensor_operation::element_wise::PassThrough;
PassThrough pass_through_op;
auto threadwise_x_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
XDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
PassThrough,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
x_grid_desc_m_k,
make_multi_index(thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * XSrcVectorSize),
pass_through_op);
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
YElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
YDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, NumInput, 1>{}([&](auto I) { // input load loop
in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m_k,
make_tuple(I0, I0),
in_thread_buf_tuple(iK0)(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
},
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
I1);
unpack2(x_elementwise_op, out_data_refs, in_data_refs);
});
});
threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);
if constexpr(!SweepOnce)
{
threadwise_x_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(iK0),
x_grid_desc_m_k,
x_lds_val_buf);
threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}
});
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;
if constexpr(!SweepOnce)
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
if constexpr(!SweepOnce)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_lds_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
}
static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});
static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
if constexpr(!SweepOnce)
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
}
}
};
} // namespace ck
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -169,7 +169,8 @@ template <typename FloatAB,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -189,7 +190,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -526,7 +528,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -68,7 +68,8 @@ template <typename FloatAB,
typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
index_t RThreadTransferDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -91,7 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -495,7 +497,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
......@@ -66,7 +66,8 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -88,7 +89,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -489,7 +491,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
namespace ck {
enum struct PipelineVersion
{
v1,
v2,
};
template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(PipelineVer == PipelineVersion::v1)
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
}
}
else if constexpr(PipelineVer == PipelineVersion::v2)
{
return GridwiseGemmPipeline_v2{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
......@@ -352,6 +352,7 @@ struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
{
};
// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
template <index_t NumPrefetch, LoopScheduler LoopSched>
constexpr auto GridwiseGemmPipeline_v1_Selector()
{
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -142,7 +142,8 @@ template <typename FloatAB,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -162,7 +163,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -481,7 +483,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
......@@ -8,8 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -115,7 +114,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -136,13 +136,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe =
#if 1
remove_cvref_t<decltype(
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
#else
GridwiseGemmPipeline_v2;
#endif
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -151,7 +151,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -171,7 +172,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -519,7 +521,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -243,7 +243,8 @@ template <index_t BlockSize,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false,
index_t NumGemmKPrefetchStage = 1>
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
static constexpr auto I0 = Number<0>{};
......@@ -258,8 +259,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
......@@ -109,7 +109,9 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1>
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
......@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
......@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -117,7 +117,8 @@ template <
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumGemmKPrefetchStage = 1>
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
static constexpr auto I0 = Number<0>{};
......@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
......@@ -123,7 +123,8 @@ template <
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumGemmKPrefetchStage = 1>
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
static constexpr auto I0 = Number<0>{};
......@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment