Unverified Commit c26c154e authored by rocking's avatar rocking Committed by GitHub
Browse files

Merge branch 'develop' into avgpool_bwd

parents 0ab4fa0f 1ee99dca
add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp)
add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
......@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
y_dev.FromDevice(y.mData.data());
pass = pass && ck::utils::check_err(y, y_ref);
pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
if(updateMovingAverage)
{
......@@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
pass = pass && ck::utils::check_err(resultRunningMean,
resultRunningMean_ref,
"Incorrect running mean values");
pass = pass && ck::utils::check_err(resultRunningVariance,
resultRunningVariance_ref,
"Incorrect running variance values");
};
if(saveMeanAndInvVariance)
......@@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
pass = pass && ck::utils::check_err(
resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
pass = pass && ck::utils::check_err(resultSaveInvVariance,
resultSaveInvVariance_ref,
"Incorrect saved invvariance values");
};
};
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
......
......@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......
......@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......
......@@ -27,6 +27,21 @@
#define CK_WAVELET_MIN_BLOCK_PER_CU 2
#endif
// kernel attribute: amdgpu_waves_per_eu()
#ifdef CK_USE_WAVES_PER_EU
// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
#ifndef CK_MIN_WAVES_PER_EU
#define CK_MIN_WAVES_PER_EU 0
#endif
#ifndef CK_MAX_WAVES_PER_EU
#define CK_MAX_WAVES_PER_EU 0
#endif
#else
#define CK_USE_WAVES_PER_EU 0
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
......@@ -148,6 +163,10 @@
#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
// experimental feature: add instances using pipeline v2
#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
......@@ -173,6 +192,10 @@
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_SWDEV_3318619 0
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/get_shift.hpp"
namespace ck {
......@@ -35,10 +35,11 @@ struct BlockwiseWelford
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
template <typename CountDataType>
__device__ static inline void
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
{
int count = count_a + count_b;
CountDataType count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
......@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a = count;
}
__device__ static void Run(T& mean_value, T& var_value, int& count)
template <typename CountDataType>
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
{
__shared__ T mean_block_buf[BlockSize];
__shared__ T var_block_buf[BlockSize];
__shared__ int count_block_buf[BlockSize];
__shared__ CountDataType count_block_buf[BlockSize];
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
......@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
T mean1 = mean_block_buf[offset1];
T var1 = var_block_buf[offset1];
int count1 = count_block_buf[offset1];
T mean1 = mean_block_buf[offset1];
T var1 = var_block_buf[offset1];
CountDataType count1 = count_block_buf[offset1];
T mean2 = mean_block_buf[offset2];
T var2 = var_block_buf[offset2];
int count2 = count_block_buf[offset2];
T mean2 = mean_block_buf[offset2];
T var2 = var_block_buf[offset2];
CountDataType count2 = count_block_buf[offset2];
Merge(mean1, var1, count1, mean2, var2, count2);
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/get_shift.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......
......@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in,
void* p_wei,
const void* p_out,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace ck {
namespace tensor_operation {
......@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto grid_desc_m_g = make_naive_tensor_descriptor(
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
......@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k = make_naive_tensor_descriptor(
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
......@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
// we want the blkGroupSize be not more than 16
if(testBlkGroupSize <= 16)
break;
iterations++;
......@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
void* control_;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
......@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
// workspace for barrier objects, each barrier object consists of two integers
// TODO: allocate barrier object memory globally to reuse it by other operators
workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
sizeof(int) * 2;
}
return (workspace_size);
......@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
......@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
index_t count_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
M_BlockTileSize * sizeof(int) * 2;
hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
};
};
......@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockBatchNormForward_ =
GridwiseMultiblockBatchNormForward<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
......@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
index_t numMeanVarCountBlockTileIteration =
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time +=
launch_and_time_kernel(stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time +=
launch_and_time_kernel(stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
numMeanVarCountBlockTileIteration,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
// It is found that:
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
// two-kernel method for gfx1030
// 2) Profiler on gfx908 could hang even though it works when running examples
// 3) Single-kernel method works on gfx1100, but the performance it not better
// than two-kernel method (due to more warps participating the barrier)
if(ck::get_device_name() == "gfx90a")
{
const auto kern_multiblock_batchnorm_fwd_ =
kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(
stream_config,
kern_multiblock_batchnorm_fwd_,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
// workspace by multiple workgroups
mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
// workspace by each workgroup
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
static_cast<int*>(arg.control_),
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time += launch_and_time_kernel(
stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time += launch_and_time_kernel(
stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
};
}
else
{
......
......@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_k_wos_lengths[0]},
num_gemm_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
// number of GEMM
num_gemm_ = YTilde * XTilde;
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
......@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
void Print() const
{
for(index_t i = 0; i < num_gemm_; i++)
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
std::cout << "a_grid_desc_ak0_m_ak1_container_"
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
......@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition
index_t num_group_;
index_t num_gemm_;
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
......@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
float ave_time = 0;
for(index_t i = 0; i < arg.num_gemm_; i++)
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
......@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
......@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
......
......@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k)
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const ck::index_t batch_k)
{
using namespace ck;
......@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_G_;
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
std::array<ck::index_t, NDimSpatial> input_left_pads_;
std::array<ck::index_t, NDimSpatial> input_right_pads_;
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_;
};
......@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads,
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment