Unverified commit 4b70d68e, authored by Chao Liu, committed by GitHub

Merge branch 'develop' into add_fp16_wmma_conv_instance

Parents: 212b9299 f82bd593
@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
     MakeArgumentPointer(const void* p_in,
                         void* p_wei,
                         const void* p_out,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
...
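The two new parameters carry full tensor strides: NDimSpatial + 3 entries each, covering the G, N, C dimensions plus the spatial ones, so callers can describe non-packed layouts. A hedged sketch of filling input_strides for a packed 2-D GNHWC tensor (the dimension ordering and the helper are illustrative assumptions, not taken from this commit):

#include <array>

using index_t = int; // stand-in for ck::index_t

// Hypothetical helper: strides for the (G, N, C, Hi, Wi) dimensions of a
// tensor stored contiguously in G-N-Hi-Wi-C order.
std::array<index_t, 2 + 3> packed_gnhwc_input_strides(index_t N, index_t Hi, index_t Wi, index_t C)
{
    return {N * Hi * Wi * C, // stride of G
            Hi * Wi * C,     // stride of N
            1,               // stride of C (innermost dimension)
            Wi * C,          // stride of Hi
            C};              // stride of Wi
}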
@@ -10,12 +10,14 @@
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/device/welford_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
-#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/hip_check_error.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
     {
-        const auto grid_desc_m_g =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
+        const auto grid_desc_m_g = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));

         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
     {
         const auto reduceLength = blkGroupSize;
-        const auto grid_desc_m_k =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
+        const auto grid_desc_m_k = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));

         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
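Both descriptor changes swap a packed (row-major) layout for an explicitly strided one with stride 1 along the invariant (M) dimension: the workspace becomes column-major, so consecutive M indices are contiguous in memory. A small sketch of the offset arithmetic, in plain C++ rather than the CK descriptor machinery:

#include <cstddef>

// offset under make_naive_tensor_descriptor_packed(make_tuple(M, G)):
std::size_t packed_offset(std::size_t m, std::size_t g, std::size_t G)
{
    return m * G + g; // row-major: G varies fastest
}

// offset under make_naive_tensor_descriptor(make_tuple(M, G), make_tuple(1, M)):
std::size_t strided_offset(std::size_t m, std::size_t g, std::size_t M)
{
    return m + g * M; // column-major: M varies fastest
}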
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);

-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            // we want the blkGroupSize to be not more than 16
+            if(testBlkGroupSize <= 16)
                 break;

             iterations++;
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         void* workspace_mean_;
         void* workspace_variance_;
         void* workspace_count_;
+
+        void* control_;
     };

     size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // workspace for welford intermediate count
             workspace_size +=
                 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
+
+            // workspace for barrier objects, each barrier object consists of two integers
+            // TODO: allocate barrier object memory globally to reuse it by other operators
+            workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
+                              sizeof(int) * 2;
         }

         return (workspace_size);
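Putting the pieces together, the workspace now budgets three per-workgroup Welford buffers plus one two-int barrier object per M-tile. A standalone sketch of the full computation, assuming the mean and variance terms (not visible in this hunk) mirror the count term with sizeof(MeanVarDataType):

#include <cstddef>
#include <cstdint>

std::size_t batchnorm_fwd_workspace_bytes(std::size_t invariant_length,
                                          std::size_t blk_group_size,
                                          std::size_t mean_var_elem_size, // sizeof(MeanVarDataType)
                                          std::size_t m_block_tile_size)
{
    std::size_t sz = 0;
    sz += invariant_length * blk_group_size * mean_var_elem_size + 64;   // welford mean (assumed)
    sz += invariant_length * blk_group_size * mean_var_elem_size + 64;   // welford variance (assumed)
    sz += invariant_length * blk_group_size * sizeof(std::int32_t) + 64; // welford count
    // one barrier object (two ints) per M-tile of the invariant dimension
    sz += (invariant_length + m_block_tile_size - 1) / m_block_tile_size * sizeof(int) * 2;
    return sz;
}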
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
         {
             // setup buffer used for intermediate welford mean
-
             pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // setup buffer used for intermediate welford count
             pArg_->workspace_count_ =
                 reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
+
+            index_t count_space_sz =
+                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
+
+            count_space_sz = math::integer_least_multiple(count_space_sz, 64);
+
+            pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
+
+            index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
+                                       M_BlockTileSize * sizeof(int) * 2;
+
+            hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
         };
     };
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
         using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);

+        using GridwiseMultiblockBatchNormForward_ =
+            GridwiseMultiblockBatchNormForward<XDataType,
+                                               YDataType,
+                                               AccDataType,
+                                               ScaleDataType,
+                                               BiasDataType,
+                                               MeanVarDataType,
+                                               YElementwiseOp,
+                                               XYGridDesc_M_K,
+                                               MeanVarCountGridDesc_M_G,
+                                               MeanVarCountGridDesc_M_K,
+                                               ScaleBiasMeanVarGridDesc_M,
+                                               ScaleBiasMeanVarGridDesc_M,
+                                               GetReduceCountPerThreadFunctor,
+                                               BlockSize,
+                                               MThreadClusterSize,
+                                               KThreadClusterSize,
+                                               MThreadSliceSize,
+                                               KThreadSliceSize,
+                                               XSrcYDstVectorDim,
+                                               XSrcVectorSize,
+                                               YDstVectorSize,
+                                               ScaleSrcVectorSize,
+                                               BiasSrcVectorSize,
+                                               MeanVarSrcDstVectorSize>;
+
         using GridwiseMultiblockWelfordFirstHalf_ =
             GridwiseMultiblockWelfordFirstHalf<XDataType,
                                                AccDataType,
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
-            index_t numMeanVarCountBlockTileIteration =
-                (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
-
-            const auto kern_multiblock_welford_first_half =
-                kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
-                                                     XDataType,
-                                                     MeanVarDataType,
-                                                     XYGridDesc_M_K,
-                                                     MeanVarCountGridDesc_M_G,
-                                                     GetReduceCountPerThreadFunctor>;
-
-            const auto kern_welford_second_half_batchnorm_forward_final =
-                kernel_welford_second_half_batchnorm_forward_final<
-                    GridwiseWelfordSecondHalfBatchNormForwardFinal_,
-                    XDataType,
-                    YDataType,
-                    AccDataType,
-                    ScaleDataType,
-                    BiasDataType,
-                    MeanVarDataType,
-                    YElementwiseOp,
-                    XYGridDesc_M_K,
-                    MeanVarCountGridDesc_M_K,
-                    ScaleBiasMeanVarGridDesc_M,
-                    ScaleBiasMeanVarGridDesc_M>;
-
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_multiblock_welford_first_half,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_g,
-                                       get_reduce_count_per_thread,
-                                       arg.numBlockTileIteration_,
-                                       arg.p_x_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_));
-
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_welford_second_half_batchnorm_forward_final,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       arg.y_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_k,
-                                       arg.scale_grid_desc_m_,
-                                       arg.bias_grid_desc_m_,
-                                       arg.mean_var_grid_desc_m_,
-                                       arg.blkGroupSize_,
-                                       arg.numBlockTileIteration_,
-                                       numMeanVarCountBlockTileIteration,
-                                       arg.epsilon_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_),
-                                       arg.p_x_,
-                                       arg.p_scale_,
-                                       arg.p_bias_,
-                                       arg.y_elementwise_op_,
-                                       arg.p_y_,
-                                       arg.updateMovingAverage_,
-                                       arg.averageFactor_,
-                                       arg.resultRunningMean_,
-                                       arg.resultRunningVariance_,
-                                       arg.saveMeanInvVariance_,
-                                       arg.resultSaveMean_,
-                                       arg.resultSaveInvVariance_);
+            // It is found that:
+            // 1) gfx1030 does not support the GLC-enabled vector load/store, so the
+            //    two-kernel method is used for gfx1030
+            // 2) the profiler on gfx908 could hang even though it works when running examples
+            // 3) the single-kernel method works on gfx1100, but its performance is not better
+            //    than the two-kernel method (due to more warps participating in the barrier)
+            if(ck::get_device_name() == "gfx90a")
+            {
+                const auto kern_multiblock_batchnorm_fwd_ =
+                    kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
+                                                        XDataType,
+                                                        YDataType,
+                                                        AccDataType,
+                                                        ScaleDataType,
+                                                        BiasDataType,
+                                                        MeanVarDataType,
+                                                        YElementwiseOp,
+                                                        XYGridDesc_M_K,
+                                                        MeanVarCountGridDesc_M_G,
+                                                        MeanVarCountGridDesc_M_K,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        GetReduceCountPerThreadFunctor>;
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_batchnorm_fwd_,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
+                                                  // workspace by multiple workgroups
+                    mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
+                                                  // workspace by each workgroup
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    static_cast<int*>(arg.control_),
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_, // true or false
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_, // true or false
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            }
+            else
+            {
+                const auto kern_multiblock_welford_first_half =
+                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
+                                                         XDataType,
+                                                         MeanVarDataType,
+                                                         XYGridDesc_M_K,
+                                                         MeanVarCountGridDesc_M_G,
+                                                         GetReduceCountPerThreadFunctor>;
+
+                const auto kern_welford_second_half_batchnorm_forward_final =
+                    kernel_welford_second_half_batchnorm_forward_final<
+                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
+                        XDataType,
+                        YDataType,
+                        AccDataType,
+                        ScaleDataType,
+                        BiasDataType,
+                        MeanVarDataType,
+                        YElementwiseOp,
+                        XYGridDesc_M_K,
+                        MeanVarCountGridDesc_M_K,
+                        ScaleBiasMeanVarGridDesc_M,
+                        ScaleBiasMeanVarGridDesc_M>;
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_welford_first_half,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_));
+
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_welford_second_half_batchnorm_forward_final,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_k,
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    arg.blkGroupSize_,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    arg.p_x_,
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_,
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_,
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            };
         }
         else
         {
...
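The dispatch above keys off the device name at run time; only gfx90a takes the fused single-kernel path. A self-contained model of that check (ck::get_device_name() wraps essentially this HIP query; the helpers below are illustrative stand-ins):

#include <hip/hip_runtime.h>
#include <string>

std::string device_name()
{
    int dev = 0;
    hipGetDevice(&dev);
    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, dev);
    // gcnArchName looks like "gfx90a:sramecc+:xnack-"; keep only the gfx part
    std::string name(prop.gcnArchName);
    return name.substr(0, name.find(':'));
}

bool use_single_kernel_batchnorm() { return device_name() == "gfx90a"; }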
@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     } // function end

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         Argument(const InDataType* p_in_grid,
                  WeiDataType* p_wei_grid,
                  const OutDataType* p_out_grid,
-                 ck::index_t G,
-                 ck::index_t N,
-                 ck::index_t K,
-                 ck::index_t C,
-                 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                 std::array<ck::index_t, NDimSpatial> input_left_pads,
-                 std::array<ck::index_t, NDimSpatial> input_right_pads,
+                 const ck::index_t G,
+                 const ck::index_t N,
+                 const ck::index_t K,
+                 const ck::index_t C,
+                 const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
+                 const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,
@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         InElementwiseOperation c_element_op_;

         // for checking IsSupportedArgument()
-        index_t Conv_G_;
-        index_t Conv_N_;
-        index_t Conv_K_;
-        index_t Conv_C_;
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
-        std::array<ck::index_t, NDimSpatial> input_left_pads_;
-        std::array<ck::index_t, NDimSpatial> input_right_pads_;
+        const index_t Conv_G_;
+        const index_t Conv_N_;
+        const index_t Conv_K_;
+        const index_t Conv_C_;
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads_;

         index_t k_batch_;
     };
@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     static auto MakeArgument(const InDataType* p_in_grid,
                              WeiDataType* p_wei_grid,
                              const OutDataType* p_out_grid,
-                             ck::index_t G,
-                             ck::index_t N,
-                             ck::index_t K,
-                             ck::index_t C,
-                             std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                             std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                             std::array<ck::index_t, NDimSpatial> input_left_pads,
-                             std::array<ck::index_t, NDimSpatial> input_right_pads,
+                             const ck::index_t G,
+                             const ck::index_t N,
+                             const ck::index_t K,
+                             const ck::index_t C,
+                             const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                             const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                             const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                             const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                             const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                             const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,
@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     MakeArgumentPointer(const void* p_in_grid,
                         void* p_wei_grid,
                         const void* p_out_grid,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
                                                    PassThroughOp,
                                                    ThreadBufferLengths_M_1,
                                                    Sequence<0, 1>,
-                                                   1,
+                                                   0,
                                                    1,
                                                    InMemoryDataOperationEnum::Set,
                                                    1,
...
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
                                                    PassThroughOp,
                                                    ThreadBufferLengths_M_1,
                                                    Sequence<0, 1>,
-                                                   1,
+                                                   0,
                                                    1,
                                                    InMemoryDataOperationEnum::Set,
                                                    1,
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
     const MeanVarGridDesc_M mean_var_grid_desc_m,
     index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
-    index_t num_mean_var_count_k_block_tile_iteration,
     AccDataType epsilon,
     const MeanVarDataType* const __restrict__ p_in_welford_mean,
     const MeanVarDataType* const __restrict__ p_in_welford_variance,
...
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
         mean_var_grid_desc_m,
         blkgroup_size,
         num_xy_k_block_tile_iteration,
-        num_mean_var_count_k_block_tile_iteration,
         epsilon,
         p_in_welford_mean,
         p_in_welford_variance,
...
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
                                const MeanVarGridDesc_M& mean_var_grid_desc_m,
                                index_t blkgroup_size,
                                index_t num_xy_k_block_tile_iteration,
-                               index_t num_mean_var_count_k_block_tile_iteration,
                                AccDataType epsilon,
                                const MeanVarDataType* const __restrict__ p_in_welford_mean,
                                const MeanVarDataType* const __restrict__ p_in_welford_variance,
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
                                                   decltype(thread_buffer_desc_m_1),
                                                   ThreadBufferLengths_M_1,
                                                   Sequence<0, 1>,
-                                                  1,
+                                                  0,
                                                   1,
                                                   1,
                                                   true>(
...
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
                                                   decltype(thread_buffer_desc_m_1),
                                                   ThreadBufferLengths_M_1,
                                                   Sequence<0, 1>,
-                                                  1,
+                                                  0,
                                                   1,
                                                   1,
                                                   true>(
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
         const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

-        constexpr auto mean_var_count_thread_copy_step_m_k =
-            make_multi_index(0, KThreadClusterSize * 1);
-
         // Step 1: do final welford reduction to get mean and variance
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
             welford_count_thread_buf(I) = 0;
         });

-        for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
-            ++reducedTiles)
+        constexpr auto mean_var_count_thread_copy_step_m_k =
+            make_multi_index(0, KThreadClusterSize);
+
+        int32_t reducedSize = 0;
+
+        while(reducedSize < blkgroup_size)
         {
             threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                              welford_mean_global_val_buf,
...
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
                                       welford_var_thread_buf,
                                       welford_count_thread_buf);

+            reducedSize += KThreadClusterSize;
+
             threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                             mean_var_count_thread_copy_step_m_k);
             threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
...
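The loop rewrite replaces a host-computed trip count (the removed num_mean_var_count_k_block_tile_iteration parameter) with a running total checked against blkgroup_size. A scalar model of the new control flow (plain ints stand in for the CK template parameters):

// Each iteration consumes one K-slab of KThreadClusterSize partial results;
// the while condition covers blkgroup_size even when it is not a multiple of
// the cluster size.
int count_reduction_iterations(int blkgroup_size, int k_thread_cluster_size)
{
    int iterations  = 0;
    int reducedSize = 0;
    while(reducedSize < blkgroup_size)
    {
        // ... load one (M, KThreadClusterSize) tile and merge it via Welford ...
        reducedSize += k_thread_cluster_size;
        ++iterations;
    }
    return iterations; // == ceil(blkgroup_size / k_thread_cluster_size)
}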
@@ -3,6 +3,8 @@

 #pragma once

+#include <iostream>
+
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
         do
         {
+#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+            __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT);
+#endif
+
             block_sync_lds();

             // GEMM i
...
@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
+#if CK_USE_WAVES_PER_EU
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
+#endif
     kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
...
@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
+#if CK_USE_WAVES_PER_EU
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
+#endif
     kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
 {
...
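amdgpu_waves_per_eu is a clang occupancy hint: it asks the backend to allocate registers so that between min and max wavefronts can stay resident per execution unit. A minimal sketch of how the new macros might be wired up at build time (the values are illustrative, not CK defaults):

#include <hip/hip_runtime.h>

#define CK_USE_WAVES_PER_EU 1
#define CK_MIN_WAVES_PER_EU 2 // keep at least 2 waves resident per EU
#define CK_MAX_WAVES_PER_EU 4 // cap register pressure for at most 4

__global__ void
#if CK_USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
example_kernel(float* out)
{
    out[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f;
}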
@@ -29,7 +29,9 @@ enum struct MfmaInstr
     mfma_i32_16x16x16i8,
     mfma_i32_32x32x16i8,
     mfma_i32_16x16x32i8,
-    mfma_f64_16x16x4f64
+    mfma_f64_16x16x4f64,
+    mfma_f32_32x32x16f8f8,
+    mfma_f32_16x16x32f8f8
 };

 template <MfmaInstr instr>
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
 {
...
@@ -594,6 +640,18 @@ struct MfmaSelector
     }
 #endif

+    template <>
+    static constexpr auto GetMfma<f8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_f32_32x32x16f8f8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<f8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_f32_16x16x32f8f8;
+    }
+
     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};

     __host__ __device__ constexpr MfmaSelector()
@@ -794,7 +852,7 @@ struct XdlopsGemm
     {
         static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                           is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value,
+                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
                       "base base_type must be double, float, half, bfloat16, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
...
@@ -1114,13 +1114,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
     uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

-    return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
-        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+    if constexpr(is_same<scalar_t, f8_t>::value)
+    {
+        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
+            src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+        return bit_cast<vector_t>(tmp);
+    }
+    else
+    {
+        return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
+            src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+    }
 #else
-    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
-        src_wave_buffer_resource, src_thread_addr_offset, 0);
-
-    return src_thread_element_valid ? tmp : vector_t(0);
+    if constexpr(is_same<scalar_t, f8_t>::value)
+    {
+        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
+            src_wave_buffer_resource, src_thread_addr_offset, 0);
+        return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
+    }
+    else
+    {
+        vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
+            src_wave_buffer_resource, src_thread_addr_offset, 0);
+        return src_thread_element_valid ? tmp : vector_t(0);
+    }
 #endif
 }
@@ -1179,13 +1196,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
     uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

-    amd_buffer_store_impl<scalar_t, vector_size, coherence>(
-        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+    if constexpr(is_same<scalar_t, f8_t>::value)
+    {
+        auto tmp =
+            bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
+        amd_buffer_store_impl<int8_t, vector_size, coherence>(
+            tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+    }
+    else
+    {
+        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
+            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+    }
 #else
     if(dst_thread_element_valid)
     {
-        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
-            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+        if constexpr(is_same<scalar_t, f8_t>::value)
+        {
+            auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
+                src_thread_data);
+            amd_buffer_store_impl<int8_t, vector_size, coherence>(
+                tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+        }
+        else
+        {
+            amd_buffer_store_impl<scalar_t, vector_size, coherence>(
+                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+        }
     }
 #endif
 }
...
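Both branches route f8 payloads through the int8 buffer intrinsics and bit_cast the result, since the buffer_load/store builtins have no fp8 overload; the bytes are reinterpreted, never numerically converted. The same trick with standard C++ types (std::bit_cast is C++20; CK uses its own bit_cast):

#include <bit>
#include <cstdint>

struct f8_mock { std::uint8_t bits; }; // stand-in for f8_t: one byte of fp8 payload

// store path: reinterpret the fp8 byte as int8 so an int8 intrinsic can move it
std::int8_t as_int8(f8_mock v) { return std::bit_cast<std::int8_t>(v); }

// load path: reinterpret the raw int8 back; the exact bit pattern round-trips
f8_mock as_f8(std::int8_t raw) { return std::bit_cast<f8_mock>(raw); }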
@@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
 #endif
     }
 };
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16f8f8;
+
+template <>
+struct intrin_mfma_f32_32x32x16f8f8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
+                bit_cast<long>(reg_a),
+                bit_cast<long>(reg_b),
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0,
+                0,
+                0);
+#else
+        vector_type<f8_t, 8> reg_a_v(reg_a);
+        vector_type<f8_t, 8> reg_b_v(reg_b);
+
+        static_for<0, 8, 1>{}([&](auto k) {
+            float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
+            float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
+
+            intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
+        });
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32f8f8;
+
+template <>
+struct intrin_mfma_f32_16x16x32f8f8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
+            bit_cast<long>(reg_a),
+            bit_cast<long>(reg_b),
+            reg_c.template AsType<float4_t>()[Number<0>{}],
+            0,
+            0,
+            0);
+#else
+        vector_type<f8_t, 8> reg_a_v(reg_a);
+        vector_type<f8_t, 8> reg_b_v(reg_b);
+
+        static_for<0, 8, 1>{}([&](auto k) {
+            float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
+            float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
+
+            intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
+        });
+#endif
+    }
+};
+
 } // namespace ck
 #endif
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
template <index_t N>
static constexpr __device__ index_t get_shift()
{
return (get_shift<N / 2>() + 1);
};
template <>
constexpr __device__ index_t get_shift<1>()
{
return (0);
}
} // namespace ck
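get_shift<N>() recursively computes log2(N) for a power-of-two N at compile time (the get_shift<1>() specialization terminates the recursion at 0), typically so a division by a tile size can become a right shift. A hedged usage sketch in device code (the kernel is illustrative, not from this commit):

__global__ void divide_by_tile_size(int* out, int linear_id)
{
    constexpr ck::index_t shift = ck::get_shift<64>(); // 64 == 2^6, so shift == 6
    out[0]                      = linear_id >> shift;  // same as linear_id / 64
}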