Unverified Commit 8c4897d1 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-756

parents 9ba9ebec 9e86ebd6
...@@ -195,6 +195,51 @@ struct AddMultiply ...@@ -195,6 +195,51 @@ struct AddMultiply
} }
}; };
// C = A * B
// E = C x D0 + D1
struct MultiplyAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c * d0) + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = type_convert<half_t>(c) * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = c * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
const float& c,
const float& d0,
const float& d1) const
{
const float y = c * d0 + d1;
e = y;
}
};
// E = FastGelu(C + D0 + D1) // E = FastGelu(C + D0 + D1)
struct AddAddFastGelu struct AddAddFastGelu
{ {
......
...@@ -39,6 +39,12 @@ struct PassThrough ...@@ -39,6 +39,12 @@ struct PassThrough
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{
y = type_convert<half_t>(x);
}
template <> template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace ck {
template <typename GridwiseMultiblockBatchNormForward_,
typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_batchnorm_forward(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
y_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
mean_var_count_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
epsilon,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count,
p_control,
p_scale,
p_bias,
y_elementwise_op,
p_y,
updateMovingAverage,
averageFactor,
resultRunningMean,
resultRunningVariance,
saveMeanInvVariance,
resultSaveMean,
resultSaveInvVariance);
};
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseMultiblockBatchNormForward
{
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadwiseWelford1 =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using ThreadwiseWelford2 =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford1 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using BlockwiseWelford2 = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& y_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count,
int32_t* const __restrict__ p_control,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
using ck::math::sqrt;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
if(block_local_id == 0)
gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
tmp_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> tmp_count_thread_buf;
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
// Step 1: each workgroup does local welford reduction
auto threadwise_welford_1 = ThreadwiseWelford1();
threadwise_welford_1.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
count_thread_buf(I) = threadwise_welford_1.cur_count_;
BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto mean_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto var_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto count_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_mean_var_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_count_store_m_g =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
if(thread_k_cluster_id == 0)
{
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_m_g,
mean_global_val_buf);
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
var_thread_buf,
mean_var_count_grid_desc_m_g,
var_global_val_buf);
threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
count_thread_buf,
mean_var_count_grid_desc_m_g,
count_global_val_buf);
};
gms_barrier(&p_control[blkgroup_id * 2]);
if(block_local_id == 0)
gms_reset(&p_control[blkgroup_id * 2]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
0,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
count_thread_buf(I) = 0;
});
constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
int32_t reducedSize = 0;
while(reducedSize < blkgroup_size)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
tmp_count_thread_buf);
ThreadwiseWelford2::Run(tmp_mean_thread_buf,
tmp_var_thread_buf,
tmp_count_thread_buf,
mean_thread_buf,
var_thread_buf,
count_thread_buf);
reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_read_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
});
// Step 4: do normalization using the mean/variance
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
YElementwiseOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
y_elementwise_op);
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasSrcVectorSize,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias, bias_grid_desc_m.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y, y_grid_desc_m_k.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_bias_load.Run(bias_grid_desc_m,
bias_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
bias_thread_buf);
threadwise_x_load.SetSrcSliceOrigin(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
AccDataType fused_mean_bias =
bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
// normalize
y_thread_buf(Number<offset>{}) =
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
}
// Step 5: update the moving average of mean and variance (optional)
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_var_thread_buf;
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_var_load =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
running_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf);
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
mean_thread_buf[I] * averageFactor;
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
var_thread_buf[I] * averageFactor;
});
auto threadwise_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf,
mean_var_grid_desc_m,
running_mean_global_buf);
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf,
mean_var_grid_desc_m,
running_var_global_buf);
};
// Step 6: save mean and inv-variance (optional)
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
{
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
});
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
result_mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
var_thread_buf,
mean_var_grid_desc_m,
result_inv_var_global_buf);
};
}
}; // namespace ck
} // namespace ck
...@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype( using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
......
...@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf ...@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
1, 0,
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
...@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf ...@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
1, 0,
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
......
...@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( ...@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance, const MeanVarDataType* const __restrict__ p_in_welford_variance,
...@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( ...@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m, mean_var_grid_desc_m,
blkgroup_size, blkgroup_size,
num_xy_k_block_tile_iteration, num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon, epsilon,
p_in_welford_mean, p_in_welford_mean,
p_in_welford_variance, p_in_welford_variance,
...@@ -123,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -123,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype( using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
...@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance, const MeanVarDataType* const __restrict__ p_in_welford_variance,
...@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
1, 0,
1, 1,
1, 1,
true>( true>(
...@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
1, 0,
1, 1,
1, 1,
true>( true>(
...@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance // Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf(I) = 0; welford_count_thread_buf(I) = 0;
}); });
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; constexpr auto mean_var_count_thread_copy_step_m_k =
++reducedTiles) make_multi_index(0, KThreadClusterSize);
int32_t reducedSize = 0;
while(reducedSize < blkgroup_size)
{ {
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf, welford_mean_global_val_buf,
...@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf, welford_var_thread_buf,
welford_count_thread_buf); welford_count_thread_buf);
reducedSize += KThreadClusterSize;
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k); mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
......
...@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype( using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
#include <stdlib.h>
namespace ck { namespace ck {
...@@ -669,4 +671,406 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -669,4 +671,406 @@ struct BlockToCTileMap_3DGrid_KSplit
} }
}; };
enum StreamKReductionStrategy
{
Atomic = 0, // sk block use atomic to do reduction
Reduction, // let some workgroup responsible for doing the reduction operation
};
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t dp_start_block_idx;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv eqav_tiles_big; // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
uint32_t k,
uint32_t num_cu,
uint32_t occupancy,
uint32_t sk_blocks = 0xffffffff)
{
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t full_dispatches = num_tiles / num_cu;
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
uint32_t sk_occupancy = occupancy;
uint32_t dp_tiles = full_dispatch_tiles;
uint32_t sk_tiles = partial_dispatche_tiles;
if(full_dispatches < occupancy)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy = 1; // TODO: single occ seems better
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy = 1; // left 1 slot for sk occupancy
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
dp_tiles = full_dispatch_tiles - num_cu;
sk_tiles = partial_dispatche_tiles + num_cu;
}
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
uint32_t max_sk_tiles =
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
: math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
// if use dp for sk-block, how many iters do we need
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
uint32_t tentative_sk_iters_per_block =
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
if(tentative_sk_blocks % sk_tiles != 0)
{
// penalty for uneven divide
cross_sk_blocks_overhead +=
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
}
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
if(tentative_sk_score < best_sk_score)
{
best_sk_score = tentative_sk_score;
sk_num_blocks = tentative_sk_blocks;
}
}
if(best_sk_score >= dp_for_sk_iters)
{
sk_num_blocks = 0;
}
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
if(sk_num_blocks == 0)
{
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
else
{
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
#if 0
printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
"sk_tiles:%u, workspace(acc float):%u\n",
num_cu,
occupancy,
get_grid_dims().x,
num_tiles,
dp_tiles,
sk_num_big_blocks,
sk_num_blocks,
sk_total_iters,
dp_start_block_idx,
dp_iters_per_block,
dp_num_blocks,
k_iters_per_tile.get(),
k_iters_per_big_block,
reduction_start_block_idx,
get_sk_tiles(),
get_workspace_size(sizeof(float)));
#endif
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ dim3 get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
}
else
return dim3(reduction_start_block_idx, 1, 1);
}
__device__ uint32_t get_block_idx() const
{
// TODO: swizzle block index for better locality
return __builtin_amdgcn_readfirstlane(blockIdx.x);
}
__device__ void
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
uint32_t iter_end,
uint32_t total_iter_length) const
{
uint32_t iter_length_mod, iter_length_quo /*unused*/;
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
uint32_t current_iter_length = math::min(
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
return current_iter_length;
}
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
__device__ void
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
{
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// swizzle tile
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
? tile_swizzle_sub_m
: tile_swizzle_sub_m_rem;
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
uint32_t quo_, rem_;
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_eqav_tiles_ + rem_;
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck } // namespace ck
...@@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 = using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( EGridDesc_M_N{}))>;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>; using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock =
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>; MeanVarGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CountGridDescriptor_MBlock_MPerBlock_NBlock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap = using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
......
...@@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
......
...@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat ...@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
index_t B0BlockTransferDstScalarPerVector_BK1, index_t B0BlockTransferDstScalarPerVector_BK1,
bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t B0BlockLdsExtraN, index_t B0BlockLdsExtraN,
index_t CDE0BlockTransferSrcVectorDim,
index_t CDE0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -444,14 +446,17 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -444,14 +446,17 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
e1_grid_desc_m_n); e1_grid_desc_m_n);
} }
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype( using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap = using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
...@@ -710,13 +715,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -710,13 +715,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
I1, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
I1, // MWaveId m1, // MWaveId
I1, // NWaveId n1, // NWaveId
I1, // MPerXdl m2, // MPerXdl
I1, // NGroupNum n2, // NGroupNum
I1, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // registerNum
auto d0s_thread_buf = generate_tuple( auto d0s_thread_buf = generate_tuple(
...@@ -732,8 +737,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -732,8 +737,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( static_assert(CDE0BlockTransferSrcScalarPerVector <= n4,
make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4)); "vector load must be not greater than n4");
static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0);
auto d0s_threadwise_copy = generate_tuple( auto d0s_threadwise_copy = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -742,10 +748,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -742,10 +748,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
A0B0B1DataType, A0B0B1DataType,
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>, Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, 9, // CDE0BlockTransferSrcVectorDim
n4, CDE0BlockTransferSrcScalarPerVector,
1, 1,
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(block_work_idx[I0], // MBlockId make_multi_index(block_work_idx[I0], // MBlockId
...@@ -898,66 +913,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -898,66 +913,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
blockwise_gemm0, blockwise_gemm0,
acc0_thread_buf, acc0_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// bias+gelu // multiple d
if constexpr(NumD0Tensor)
{ {
static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) { d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
static_for<0, n2, 1>{}([&](auto groupid) { d0s_grid_buf[i],
static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0s_threadwise_copy(i).Run( make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0s_thread_buf(i));
d0s_grid_buf[i], });
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), // get reference to src data
d0s_thread_buf(i)); const auto src_data_refs = generate_tie(
}); // return type should be lvalue
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
static_for<0, n4, 1>{}([&](auto i) { Number<NumD0Tensor>{});
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
make_tuple(mr, nr, groupid, i)); // get reference to dst data
auto dst_data_refs = generate_tie(
// get reference to src data // return type should be lvalue
const auto src_data_refs = generate_tie( [&](auto) -> auto& { return acc0_thread_buf(i); },
// return type should be lvalue Number<2>{});
[&](auto iSrc) -> const auto& {
return d0s_thread_buf[iSrc][i]; unpack2(cde0_element_op, dst_data_refs, src_data_refs);
},
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& {
return acc0_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
});
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
});
}); });
static_for<0, NumD0Tensor, 1>{}([&](auto i) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}); });
} }
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
}
// gemm1 // gemm1
{ {
// TODO: explore using dynamic buffer for a1 thread buffer // TODO: explore using dynamic buffer for a1 thread buffer
......
...@@ -114,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -114,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -368,12 +368,14 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -368,12 +368,14 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Number<NumD0Tensor>{}); Number<NumD0Tensor>{});
} }
using D0sGridPointer = decltype(MakeD0sGridPointer()); using D0sGridPointer = decltype(MakeD0sGridPointer());
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype( using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C1GridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
......
...@@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
......
...@@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C1GridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock = using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>; remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
......
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
...@@ -17,6 +19,8 @@ ...@@ -17,6 +19,8 @@
namespace ck { namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -25,7 +29,8 @@ template <typename GridwiseGemm, ...@@ -25,7 +29,8 @@ template <typename GridwiseGemm,
typename CGridDesc_M0_M10_M11_N0_N10_N11, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -38,6 +43,13 @@ __global__ void ...@@ -38,6 +43,13 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -88,7 +100,8 @@ template <index_t BlockSize, ...@@ -88,7 +100,8 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3 struct GridwiseGemmDl_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
c_grid_desc_m_n); c_grid_desc_m_n);
} }
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
...@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto c_m0_n0_block_cluster_idx = const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR // HACK: this forces index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
...@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
...@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step); b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
...@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 ...@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
b_block_slice_copy_step); b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
......
...@@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 = using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// Support 2 dimension in the future. Not only M // Support 2 dimension in the future. Not only M
using RGridDescriptor_MBlock_MPerBlock = using RGridDescriptor_MBlock_MPerBlock =
......
...@@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -565,10 +565,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -565,10 +565,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
e_grid_desc_m_n); e_grid_desc_m_n);
} }
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( DsGridDesc_M_N{}))>;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>;
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
......
...@@ -26,7 +26,9 @@ namespace ck { ...@@ -26,7 +26,9 @@ namespace ck {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// Assume: // Assume:
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
template <typename ABDataType, // FIXME: don't assume A/B have same datatype template <typename ADataType,
typename BDataType,
typename ComputeDataType_,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
...@@ -89,18 +91,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -89,18 +91,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX #if CK_WORKAROUND_DENORM_FIX
using ABDataTypeAdjusted = using ComputeDataType =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>; conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
#else #else
using ABDataTypeAdjusted = ABDataType; using ComputeDataType = ComputeDataType_;
#endif #endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
...@@ -170,7 +168,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -170,7 +168,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType), sizeof(ComputeDataType),
c_block_size * sizeof(CShuffleDataType)); c_block_size * sizeof(CShuffleDataType));
} }
...@@ -266,12 +264,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -266,12 +264,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{ {
return false; return false;
} }
...@@ -289,13 +288,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -289,13 +288,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// check tile size // check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{ {
return false; return false;
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock; const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
...@@ -312,8 +311,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -312,8 +311,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check tensor size: cannot be larger than 2GB each // check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31); constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{ {
return false; return false;
...@@ -337,8 +336,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -337,8 +336,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid, __device__ static void Run(const ADataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
...@@ -407,8 +406,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -407,8 +406,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<AK0PerBlock, MPerBlock, AK1>, Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABDataType, ADataType,
ABDataTypeAdjusted, ComputeDataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -438,8 +437,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -438,8 +437,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<BK0PerBlock, NPerBlock, BK1>, Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
ABDataType, BDataType,
ABDataTypeAdjusted, ComputeDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -469,11 +468,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -469,11 +468,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1, BK1), math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ABDataTypeAdjusted, ComputeDataType,
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -491,11 +490,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -491,11 +490,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared), static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned, static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
......
...@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2 ...@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
do do
{ {
#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
__builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
#endif
block_sync_lds(); block_sync_lds();
// GEMM i // GEMM i
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment