Commit 349552e2 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents daa98dae 87f2bbcf
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -27,6 +27,21 @@
#define CK_WAVELET_MIN_BLOCK_PER_CU 2
#endif
+
+// kernel attribute: amdgpu_waves_per_eu()
+#ifdef CK_USE_WAVES_PER_EU
+// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
+#ifndef CK_MIN_WAVES_PER_EU
+#define CK_MIN_WAVES_PER_EU 0
+#endif
+#ifndef CK_MAX_WAVES_PER_EU
+#define CK_MAX_WAVES_PER_EU 0
+#endif
+#else
+#define CK_USE_WAVES_PER_EU 0
+#endif
+
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
@@ -148,6 +163,10 @@
#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
// experimental feature: add instances using pipeline v2
#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
+// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
+#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
+#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
namespace ck {
@@ -35,10 +35,11 @@ struct BlockwiseWelford
    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+    template <typename CountDataType>
    __device__ static inline void
-    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
+    Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
    {
-        int count = count_a + count_b;
+        CountDataType count = count_a + count_b;
        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
        T delta = mean_b - mean_a;
        mean_a += delta * count_b_over_count;
@@ -46,11 +47,12 @@ struct BlockwiseWelford
        count_a = count;
    }
-    __device__ static void Run(T& mean_value, T& var_value, int& count)
+    template <typename CountDataType>
+    __device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
    {
        __shared__ T mean_block_buf[BlockSize];
        __shared__ T var_block_buf[BlockSize];
-        __shared__ int count_block_buf[BlockSize];
+        __shared__ CountDataType count_block_buf[BlockSize];
        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
@@ -76,13 +78,13 @@ struct BlockwiseWelford
        index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                             make_tuple(0, indOffset));
        T mean1 = mean_block_buf[offset1];
        T var1 = var_block_buf[offset1];
-        int count1 = count_block_buf[offset1];
+        CountDataType count1 = count_block_buf[offset1];
        T mean2 = mean_block_buf[offset2];
        T var2 = var_block_buf[offset2];
-        int count2 = count_block_buf[offset2];
+        CountDataType count2 = count_block_buf[offset2];
        Merge(mean1, var1, count1, mean2, var2, count2);
...
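For reference, the merge that the templated BlockwiseWelford::Merge above performs is the standard Chan/Welford combination of two partial (mean, M2, count) triples. Below is a minimal host-side sketch, assuming var holds the running sum of squared deviations (M2); it mirrors the zero-count guard visible in the hunk and uses illustrative names, not the CK kernel code.

// Host-side sketch of the Welford/Chan merge (illustrative only).
template <typename T, typename CountDataType>
void welford_merge(T& mean_a, T& var_a, CountDataType& count_a,
                   T mean_b, T var_b, CountDataType count_b)
{
    const CountDataType count = count_a + count_b;
    // guard against empty partitions, as the kernel does
    const T count_b_over_count = (count == 0) ? T(0) : T(count_b) / T(count);
    const T delta = mean_b - mean_a;
    mean_a += delta * count_b_over_count;
    // between-group term: delta^2 * n_a * n_b / n
    var_a += var_b + delta * delta * T(count_a) * count_b_over_count;
    count_a = count;
}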
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
...
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
-#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
namespace ck {
namespace tensor_operation {
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
    static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
    {
-        const auto grid_desc_m_g =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
+        const auto grid_desc_m_g = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
    static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto reduceLength = blkGroupSize;
-        const auto grid_desc_m_k =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
+        const auto grid_desc_m_k = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
            int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                   (K_BlockTileSize * iterations);
-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            // we want the blkGroupSize be not more than 16
+            if(testBlkGroupSize <= 16)
                break;
            iterations++;
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
        void* workspace_mean_;
        void* workspace_variance_;
        void* workspace_count_;
+
+        void* control_;
    };
    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
            // workspace for welford intermediate count
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
+
+            // workspace for barrier objects, each barrier object consists of two integers
+            // TODO: allocate barrier object memory globally to reuse it by other operators
+            workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
+                              sizeof(int) * 2;
        }
        return (workspace_size);
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // setup buffer used for intermediate welford mean
            pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
            // setup buffer used for intermediate welfor count
            pArg_->workspace_count_ =
                reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
+
+            index_t count_space_sz =
+                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
+
+            count_space_sz = math::integer_least_multiple(count_space_sz, 64);
+
+            pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
+
+            index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
+                                       M_BlockTileSize * sizeof(int) * 2;
+
+            hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
        };
    };
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
    using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
    using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
+    using GridwiseMultiblockBatchNormForward_ =
+        GridwiseMultiblockBatchNormForward<XDataType,
+                                           YDataType,
+                                           AccDataType,
+                                           ScaleDataType,
+                                           BiasDataType,
+                                           MeanVarDataType,
+                                           YElementwiseOp,
+                                           XYGridDesc_M_K,
+                                           MeanVarCountGridDesc_M_G,
+                                           MeanVarCountGridDesc_M_K,
+                                           ScaleBiasMeanVarGridDesc_M,
+                                           ScaleBiasMeanVarGridDesc_M,
+                                           GetReduceCountPerThreadFunctor,
+                                           BlockSize,
+                                           MThreadClusterSize,
+                                           KThreadClusterSize,
+                                           MThreadSliceSize,
+                                           KThreadSliceSize,
+                                           XSrcYDstVectorDim,
+                                           XSrcVectorSize,
+                                           YDstVectorSize,
+                                           ScaleSrcVectorSize,
+                                           BiasSrcVectorSize,
+                                           MeanVarSrcDstVectorSize>;
+
    using GridwiseMultiblockWelfordFirstHalf_ =
        GridwiseMultiblockWelfordFirstHalf<XDataType,
                                           AccDataType,
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
                                           BiasSrcVectorSize,
                                           MeanVarSrcDstVectorSize>;
-            index_t numMeanVarCountBlockTileIteration =
-                (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
-
-            const auto kern_multiblock_welford_first_half =
-                kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
-                                                     XDataType,
-                                                     MeanVarDataType,
-                                                     XYGridDesc_M_K,
-                                                     MeanVarCountGridDesc_M_G,
-                                                     GetReduceCountPerThreadFunctor>;
-
-            const auto kern_welford_second_half_batchnorm_forward_final =
-                kernel_welford_second_half_batchnorm_forward_final<
-                    GridwiseWelfordSecondHalfBatchNormForwardFinal_,
-                    XDataType,
-                    YDataType,
-                    AccDataType,
-                    ScaleDataType,
-                    BiasDataType,
-                    MeanVarDataType,
-                    YElementwiseOp,
-                    XYGridDesc_M_K,
-                    MeanVarCountGridDesc_M_K,
-                    ScaleBiasMeanVarGridDesc_M,
-                    ScaleBiasMeanVarGridDesc_M>;
-
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_multiblock_welford_first_half,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_g,
-                                       get_reduce_count_per_thread,
-                                       arg.numBlockTileIteration_,
-                                       arg.p_x_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_));
-
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_welford_second_half_batchnorm_forward_final,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       arg.y_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_k,
-                                       arg.scale_grid_desc_m_,
-                                       arg.bias_grid_desc_m_,
-                                       arg.mean_var_grid_desc_m_,
-                                       arg.blkGroupSize_,
-                                       arg.numBlockTileIteration_,
-                                       numMeanVarCountBlockTileIteration,
-                                       arg.epsilon_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_),
-                                       arg.p_x_,
-                                       arg.p_scale_,
-                                       arg.p_bias_,
-                                       arg.y_elementwise_op_,
-                                       arg.p_y_,
-                                       arg.updateMovingAverage_,
-                                       arg.averageFactor_,
-                                       arg.resultRunningMean_,
-                                       arg.resultRunningVariance_,
-                                       arg.saveMeanInvVariance_,
-                                       arg.resultSaveMean_,
-                                       arg.resultSaveInvVariance_);
+            // It is found that:
+            // 1) gfx1030 does not support the GLC enabled vector load/store, so using the
+            //    two-kernel method for gfx1030
+            // 2) Profiler on gfx908 could hang even though it works when running examples
+            // 3) Single-kernel method works on gfx1100, but the performance it not better
+            //    than two-kernel method (due to more warps participating the barrier)
+            if(ck::get_device_name() == "gfx90a")
+            {
+                const auto kern_multiblock_batchnorm_fwd_ =
+                    kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
+                                                        XDataType,
+                                                        YDataType,
+                                                        AccDataType,
+                                                        ScaleDataType,
+                                                        BiasDataType,
+                                                        MeanVarDataType,
+                                                        YElementwiseOp,
+                                                        XYGridDesc_M_K,
+                                                        MeanVarCountGridDesc_M_G,
+                                                        MeanVarCountGridDesc_M_K,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        GetReduceCountPerThreadFunctor>;
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_batchnorm_fwd_,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
+                                                  // workspace by multiple workgroups
+                    mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
+                                                  // workspace by each workgroup
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    static_cast<int*>(arg.control_),
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_, // true or false
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_, // true or false
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            }
+            else
+            {
+                const auto kern_multiblock_welford_first_half =
+                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
+                                                         XDataType,
+                                                         MeanVarDataType,
+                                                         XYGridDesc_M_K,
+                                                         MeanVarCountGridDesc_M_G,
+                                                         GetReduceCountPerThreadFunctor>;
+
+                const auto kern_welford_second_half_batchnorm_forward_final =
+                    kernel_welford_second_half_batchnorm_forward_final<
+                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
+                        XDataType,
+                        YDataType,
+                        AccDataType,
+                        ScaleDataType,
+                        BiasDataType,
+                        MeanVarDataType,
+                        YElementwiseOp,
+                        XYGridDesc_M_K,
+                        MeanVarCountGridDesc_M_K,
+                        ScaleBiasMeanVarGridDesc_M,
+                        ScaleBiasMeanVarGridDesc_M>;
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_welford_first_half,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_));
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_welford_second_half_batchnorm_forward_final,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_k,
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    arg.blkGroupSize_,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    arg.p_x_,
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_,
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_,
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            };
        }
        else
        {
...
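The single-kernel path above relies on the per-M-tile "barrier objects" carved out of the workspace (two integers each, zero-initialized with hipMemset). The barrier implementation itself is not shown in this diff; the fragment below is only a generic sketch, with hypothetical names, of how a two-integer grid-level barrier (an arrival counter plus a generation flag) can work, to illustrate what such an object is for.

// Generic sketch of a two-int barrier object; not CK's gridwise barrier.
// p_control[0] = arrival counter, p_control[1] = generation flag.
__device__ void gridwise_barrier_sketch(int* p_control, int num_workgroups)
{
    __syncthreads();
    if(threadIdx.x == 0)
    {
        volatile int* p_gen = p_control + 1;
        const int gen = *p_gen;
        __threadfence(); // make this workgroup's prior global writes visible first
        if(atomicAdd(&p_control[0], 1) == num_workgroups - 1)
        {
            p_control[0] = 0; // last arriver resets the counter for the next use
            __threadfence();
            *p_gen = gen + 1; // and releases all waiting workgroups
        }
        else
        {
            while(*p_gen == gen) {} // spin until the generation is bumped
        }
    }
    __syncthreads();
}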
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
            PassThroughOp,
            ThreadBufferLengths_M_1,
            Sequence<0, 1>,
-            1,
+            0,
            1,
            InMemoryDataOperationEnum::Set,
            1,
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
            PassThroughOp,
            ThreadBufferLengths_M_1,
            Sequence<0, 1>,
-            1,
+            0,
            1,
            InMemoryDataOperationEnum::Set,
            1,
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
-    index_t num_mean_var_count_k_block_tile_iteration,
    AccDataType epsilon,
    const MeanVarDataType* const __restrict__ p_in_welford_mean,
    const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
        mean_var_grid_desc_m,
        blkgroup_size,
        num_xy_k_block_tile_iteration,
-        num_mean_var_count_k_block_tile_iteration,
        epsilon,
        p_in_welford_mean,
        p_in_welford_variance,
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
        const MeanVarGridDesc_M& mean_var_grid_desc_m,
        index_t blkgroup_size,
        index_t num_xy_k_block_tile_iteration,
-        index_t num_mean_var_count_k_block_tile_iteration,
        AccDataType epsilon,
        const MeanVarDataType* const __restrict__ p_in_welford_mean,
        const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
            decltype(thread_buffer_desc_m_1),
            ThreadBufferLengths_M_1,
            Sequence<0, 1>,
-            1,
+            0,
            1,
            1,
            true>(
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
            decltype(thread_buffer_desc_m_1),
            ThreadBufferLengths_M_1,
            Sequence<0, 1>,
-            1,
+            0,
            1,
            1,
            true>(
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-        constexpr auto mean_var_count_thread_copy_step_m_k =
-            make_multi_index(0, KThreadClusterSize * 1);
-
        // Step 1: do final welford reduction to get mean and variance
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
            welford_count_thread_buf(I) = 0;
        });
-        for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
-            ++reducedTiles)
+        constexpr auto mean_var_count_thread_copy_step_m_k =
+            make_multi_index(0, KThreadClusterSize);
+
+        int32_t reducedSize = 0;
+        while(reducedSize < blkgroup_size)
        {
            threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                             welford_mean_global_val_buf,
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
                                             welford_var_thread_buf,
                                             welford_count_thread_buf);
+            reducedSize += KThreadClusterSize;
+
            threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                            mean_var_count_thread_copy_step_m_k);
            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
...
@@ -3,6 +3,8 @@
#pragma once
+#include <iostream>
+
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
        do
        {
+#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+            __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
+#endif
+
            block_sync_lds();
            // GEMM i
...
@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+#if CK_USE_WAVES_PER_EU
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
    kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+#if CK_USE_WAVES_PER_EU
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
    kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
{
...
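For illustration only (the values here are examples, not defaults taken from the diff): when a build enables both switches, e.g. CK_USE_LAUNCH_BOUNDS plus CK_USE_WAVES_PER_EU with CK_MIN_WAVES_PER_EU=1 and CK_MAX_WAVES_PER_EU=2, the kernel declaration above effectively expands to:

// illustrative expansion, assuming CK_MIN_WAVES_PER_EU=1 and CK_MAX_WAVES_PER_EU=2
__global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    __attribute__((amdgpu_waves_per_eu(1, 2)))
    kernel_gemm_xdlops_v2r3(/* ... */);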
@@ -29,7 +29,9 @@ enum struct MfmaInstr
    mfma_i32_16x16x16i8,
    mfma_i32_32x32x16i8,
    mfma_i32_16x16x32i8,
-    mfma_f64_16x16x4f64
+    mfma_f64_16x16x4f64,
+    mfma_f32_32x32x16f8f8,
+    mfma_f32_16x16x32f8f8
};
template <MfmaInstr instr>
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
    }
};
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
+{
+    static constexpr index_t group_size = 4;
+    static constexpr index_t num_groups_per_blk = 4;
+    static constexpr index_t num_regs_per_blk = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size = 64;
+    static constexpr index_t num_input_blks = 2;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 32;
+    static constexpr index_t n_per_blk = 32;
+    static constexpr index_t k_per_blk = 8;
+    static constexpr bool is_k_reduction = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
+{
+    static constexpr index_t group_size = 4;
+    static constexpr index_t num_groups_per_blk = 1;
+    static constexpr index_t num_regs_per_blk = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size = 64;
+    static constexpr index_t num_input_blks = 4;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 16;
+    static constexpr index_t n_per_blk = 16;
+    static constexpr index_t k_per_blk = 8;
+    static constexpr bool is_k_reduction = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
{
@@ -594,6 +640,18 @@ struct MfmaSelector
    }
#endif
+    template <>
+    static constexpr auto GetMfma<f8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_f32_32x32x16f8f8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<f8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_f32_16x16x32f8f8;
+    }
+
    static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
    __host__ __device__ constexpr MfmaSelector()
@@ -794,7 +852,7 @@ struct XdlopsGemm
    {
        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value,
+                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
                      "base base_type must be double, float, half, bfloat16, and int8_t!");
        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
...
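A quick consistency check on the new f8 entries (my own arithmetic, not part of the diff): with is_k_reduction set, one instruction covers k_per_blk * num_input_blks elements along K, which matches the K in the instruction names.

// Sanity check of the K coverage implied by the new mfma_type entries (illustrative).
static_assert(8 * 2 == 16, "mfma_f32_32x32x16f8f8: k_per_blk * num_input_blks matches K");
static_assert(8 * 4 == 32, "mfma_f32_16x16x32f8f8: k_per_blk * num_input_blks matches K");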