Unverified commit c26c154e, authored by rocking, committed by GitHub

Merge branch 'develop' into avgpool_bwd

parents 0ab4fa0f 1ee99dca
 add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
+add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp)
 add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
 add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
         (void)invoker_ptr_ref->Run(argument_ptr_ref.get());
         y_dev.FromDevice(y.mData.data());
-        pass = pass && ck::utils::check_err(y, y_ref);
+        pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
         if(updateMovingAverage)
         {
@@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
             resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
             resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
-            pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
-            pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
+            pass = pass && ck::utils::check_err(resultRunningMean,
+                                                resultRunningMean_ref,
+                                                "Incorrect running mean values");
+            pass = pass && ck::utils::check_err(resultRunningVariance,
+                                                resultRunningVariance_ref,
+                                                "Incorrect running variance values");
         };
         if(saveMeanAndInvVariance)
@@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
             resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
             resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
-            pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
-            pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
+            pass = pass && ck::utils::check_err(
+                resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
+            pass = pass && ck::utils::check_err(resultSaveInvVariance,
+                                                resultSaveInvVariance_ref,
+                                                "Incorrect saved invvariance values");
         };
     };
...
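The third argument added to ck::utils::check_err tags each comparison, so a failing run reports which batchnorm output diverged. A minimal sketch of that reporting pattern, with a hypothetical stand-in for check_err (not CK's actual implementation):

#include <cmath>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for ck::utils::check_err: element-wise comparison that
// prints the caller-supplied message on the first mismatch and returns false.
bool check_err(const std::vector<float>& out,
               const std::vector<float>& ref,
               const std::string& msg,
               float tol = 1e-5f)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << ": size mismatch\n";
        return false;
    }
    for(std::size_t i = 0; i < out.size(); ++i)
        if(std::fabs(out[i] - ref[i]) > tol)
        {
            std::cerr << msg << ": mismatch at index " << i << '\n';
            return false;
        }
    return true;
}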
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 using InDataType  = int8_t;
 using WeiDataType = int8_t;
...
@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
...
@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
...
@@ -27,6 +27,21 @@
 #define CK_WAVELET_MIN_BLOCK_PER_CU 2
 #endif
+// kernel attribute: amdgpu_waves_per_eu()
+#ifdef CK_USE_WAVES_PER_EU
+// for 1-wave kernels, control the arguments of the amdgpu_waves_per_eu() attribute
+#ifndef CK_MIN_WAVES_PER_EU
+#define CK_MIN_WAVES_PER_EU 0
+#endif
+#ifndef CK_MAX_WAVES_PER_EU
+#define CK_MAX_WAVES_PER_EU 0
+#endif
+#else
+#define CK_USE_WAVES_PER_EU 0
+#endif
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
@@ -148,6 +163,10 @@
 #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
 // experimental feature: add instances using pipeline v2
 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
+// experimental feature: optimize pipeline v2 by IGLP strategy (value = ID of strategy)
+#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
+#endif
 // hack: has underlying assumptions that need to be satisfied; otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
@@ -173,6 +192,10 @@
 // workaround: compiler issue on gfx908
 #define CK_WORKAROUND_SWDEV_388832 1
+// workaround: grouped Conv2d bwd-data fails for an already-implemented instance
+#define CK_WORKAROUND_SWDEV_3318619 0
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
...
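The new CK_MIN_WAVES_PER_EU/CK_MAX_WAVES_PER_EU knobs feed Clang's amdgpu_waves_per_eu kernel attribute, which trades per-wave register/LDS budget against occupancy. A minimal sketch of a kernel carrying that attribute; the kernel itself and the example bounds (2, 4) are illustrative, not from CK, and in CK the macros above would supply the bounds when CK_USE_WAVES_PER_EU is enabled:

#include <hip/hip_runtime.h>

// Asks the compiler to budget VGPRs/LDS so that between 2 and 4 waves can be
// resident per execution unit for this kernel.
__attribute__((amdgpu_waves_per_eu(2, 4))) __global__ void scale_kernel(float* p, float s, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s;
}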
@@ -4,7 +4,7 @@
 #pragma once
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
 namespace ck {
@@ -35,10 +35,11 @@ struct BlockwiseWelford
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+    template <typename CountDataType>
     __device__ static inline void
-    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
+    Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
     {
-        int count = count_a + count_b;
+        CountDataType count = count_a + count_b;
         T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
         T delta = mean_b - mean_a;
         mean_a += delta * count_b_over_count;
@@ -46,11 +47,12 @@ struct BlockwiseWelford
         count_a = count;
     }
-    __device__ static void Run(T& mean_value, T& var_value, int& count)
+    template <typename CountDataType>
+    __device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
     {
         __shared__ T mean_block_buf[BlockSize];
         __shared__ T var_block_buf[BlockSize];
-        __shared__ int count_block_buf[BlockSize];
+        __shared__ CountDataType count_block_buf[BlockSize];
         constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
@@ -76,13 +78,13 @@ struct BlockwiseWelford
             index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                  make_tuple(0, indOffset));
             T mean1 = mean_block_buf[offset1];
             T var1  = var_block_buf[offset1];
-            int count1 = count_block_buf[offset1];
+            CountDataType count1 = count_block_buf[offset1];
             T mean2 = mean_block_buf[offset2];
             T var2  = var_block_buf[offset2];
-            int count2 = count_block_buf[offset2];
+            CountDataType count2 = count_block_buf[offset2];
             Merge(mean1, var1, count1, mean2, var2, count2);
...
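For reference, Merge implements the standard pairwise combination of two Welford partial results (Chan et al.). Writing m for the means, n for the counts, and M_2 for the accumulated variance term, merging partitions a and b gives

n = n_a + n_b, \qquad \delta = m_b - m_a, \qquad m = m_a + \delta \cdot \tfrac{n_b}{n}, \qquad M_2 = M_{2,a} + M_{2,b} + \delta^2 \cdot \tfrac{n_a n_b}{n}

The mean update is exactly the mean_a += delta * count_b_over_count line above; the var_a update elided by the hunk corresponds to the M_2 term (up to whether the code carries M_2 itself or a count-normalized variance).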
@@ -4,7 +4,7 @@
 #pragma once
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/get_shift.hpp"
 #include "ck/utility/reduction_functions_accumulate.hpp"
 namespace ck {
...
@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
     MakeArgumentPointer(const void* p_in,
                         void* p_wei,
                         const void* p_out,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
...
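The two new arrays have NDimSpatial + 3 entries because they carry one stride per tensor dimension: G, N, and C plus the NDimSpatial spatial dimensions. A hedged sketch for a packed 2-D NHWGC input, assuming the {G, N, C, Hi, Wi} ordering suggested by CK's g_n_c_wis-style argument names (the ordering is an assumption here, not confirmed by this hunk):

#include <array>
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck::index_t

// Strides of a packed NHWGC input tensor, listed in {G, N, C, Hi, Wi} order.
// Memory layout is N, Hi, Wi, G, C with C innermost.
std::array<index_t, 2 + 3> make_nhwgc_input_strides(
    index_t G, index_t N, index_t C, index_t Hi, index_t Wi)
{
    (void)N; // N's own extent does not enter any stride
    return {
        C,               // stride of G
        Hi * Wi * G * C, // stride of N
        1,               // stride of C (innermost)
        Wi * G * C,      // stride of Hi
        G * C            // stride of Wi
    };
}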
@@ -10,12 +10,14 @@
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/device/welford_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
-#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
+#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
     {
-        const auto grid_desc_m_g =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
+        const auto grid_desc_m_g = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
     static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
     {
         const auto reduceLength = blkGroupSize;
-        const auto grid_desc_m_k =
-            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
+        const auto grid_desc_m_k = make_naive_tensor_descriptor(
+            make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
         const auto mPad =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);
-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            // we want the blkGroupSize to be no more than 16
+            if(testBlkGroupSize <= 16)
                 break;
             iterations++;
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         void* workspace_mean_;
         void* workspace_variance_;
         void* workspace_count_;
+        void* control_;
     };
     size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // workspace for welford intermediate count
             workspace_size +=
                 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
+            // workspace for barrier objects; each barrier object consists of two integers
+            // TODO: allocate barrier object memory globally so other operators can reuse it
+            workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
+                              sizeof(int) * 2;
         }
         return (workspace_size);
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
         {
-
             // setup buffer used for intermediate welford mean
             pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
             // setup buffer used for intermediate welford count
             pArg_->workspace_count_ =
                 reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
+            index_t count_space_sz =
+                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
+            count_space_sz = math::integer_least_multiple(count_space_sz, 64);
+            pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
+            index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
+                                       M_BlockTileSize * sizeof(int) * 2;
+            hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
         };
     };
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
         using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
         using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
+        using GridwiseMultiblockBatchNormForward_ =
+            GridwiseMultiblockBatchNormForward<XDataType,
+                                               YDataType,
+                                               AccDataType,
+                                               ScaleDataType,
+                                               BiasDataType,
+                                               MeanVarDataType,
+                                               YElementwiseOp,
+                                               XYGridDesc_M_K,
+                                               MeanVarCountGridDesc_M_G,
+                                               MeanVarCountGridDesc_M_K,
+                                               ScaleBiasMeanVarGridDesc_M,
+                                               ScaleBiasMeanVarGridDesc_M,
+                                               GetReduceCountPerThreadFunctor,
+                                               BlockSize,
+                                               MThreadClusterSize,
+                                               KThreadClusterSize,
+                                               MThreadSliceSize,
+                                               KThreadSliceSize,
+                                               XSrcYDstVectorDim,
+                                               XSrcVectorSize,
+                                               YDstVectorSize,
+                                               ScaleSrcVectorSize,
+                                               BiasSrcVectorSize,
+                                               MeanVarSrcDstVectorSize>;
         using GridwiseMultiblockWelfordFirstHalf_ =
             GridwiseMultiblockWelfordFirstHalf<XDataType,
                                                AccDataType,
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
                                                BiasSrcVectorSize,
                                                MeanVarSrcDstVectorSize>;
-            index_t numMeanVarCountBlockTileIteration =
-                (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
-            const auto kern_multiblock_welford_first_half =
-                kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
-                                                     XDataType,
-                                                     MeanVarDataType,
-                                                     XYGridDesc_M_K,
-                                                     MeanVarCountGridDesc_M_G,
-                                                     GetReduceCountPerThreadFunctor>;
-            const auto kern_welford_second_half_batchnorm_forward_final =
-                kernel_welford_second_half_batchnorm_forward_final<
-                    GridwiseWelfordSecondHalfBatchNormForwardFinal_,
-                    XDataType,
-                    YDataType,
-                    AccDataType,
-                    ScaleDataType,
-                    BiasDataType,
-                    MeanVarDataType,
-                    YElementwiseOp,
-                    XYGridDesc_M_K,
-                    MeanVarCountGridDesc_M_K,
-                    ScaleBiasMeanVarGridDesc_M,
-                    ScaleBiasMeanVarGridDesc_M>;
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_multiblock_welford_first_half,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_g,
-                                       get_reduce_count_per_thread,
-                                       arg.numBlockTileIteration_,
-                                       arg.p_x_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_));
-            avg_time +=
-                launch_and_time_kernel(stream_config,
-                                       kern_welford_second_half_batchnorm_forward_final,
-                                       dim3(arg.gridSize_),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.x_grid_desc_m_k_,
-                                       arg.y_grid_desc_m_k_,
-                                       mean_var_count_grid_desc_m_k,
-                                       arg.scale_grid_desc_m_,
-                                       arg.bias_grid_desc_m_,
-                                       arg.mean_var_grid_desc_m_,
-                                       arg.blkGroupSize_,
-                                       arg.numBlockTileIteration_,
-                                       numMeanVarCountBlockTileIteration,
-                                       arg.epsilon_,
-                                       static_cast<MeanVarDataType*>(arg.workspace_mean_),
-                                       static_cast<MeanVarDataType*>(arg.workspace_variance_),
-                                       static_cast<int32_t*>(arg.workspace_count_),
-                                       arg.p_x_,
-                                       arg.p_scale_,
-                                       arg.p_bias_,
-                                       arg.y_elementwise_op_,
-                                       arg.p_y_,
-                                       arg.updateMovingAverage_,
-                                       arg.averageFactor_,
-                                       arg.resultRunningMean_,
-                                       arg.resultRunningVariance_,
-                                       arg.saveMeanInvVariance_,
-                                       arg.resultSaveMean_,
-                                       arg.resultSaveInvVariance_);
+            // It is found that:
+            // 1) gfx1030 does not support GLC-enabled vector load/store, so the
+            //    two-kernel method is used for gfx1030
+            // 2) the profiler on gfx908 could hang, even though it works when running examples
+            // 3) the single-kernel method works on gfx1100, but its performance is not better
+            //    than the two-kernel method (due to more warps participating in the barrier)
+            if(ck::get_device_name() == "gfx90a")
+            {
+                const auto kern_multiblock_batchnorm_fwd_ =
+                    kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
+                                                        XDataType,
+                                                        YDataType,
+                                                        AccDataType,
+                                                        ScaleDataType,
+                                                        BiasDataType,
+                                                        MeanVarDataType,
+                                                        YElementwiseOp,
+                                                        XYGridDesc_M_K,
+                                                        MeanVarCountGridDesc_M_G,
+                                                        MeanVarCountGridDesc_M_K,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        ScaleBiasMeanVarGridDesc_M,
+                                                        GetReduceCountPerThreadFunctor>;
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_batchnorm_fwd_,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
+                                                  // workspace by multiple workgroups
+                    mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
+                                                  // workspace by each workgroup
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    static_cast<int*>(arg.control_),
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_, // true or false
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_, // true or false
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            }
+            else
+            {
+                const auto kern_multiblock_welford_first_half =
+                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
+                                                         XDataType,
+                                                         MeanVarDataType,
+                                                         XYGridDesc_M_K,
+                                                         MeanVarCountGridDesc_M_G,
+                                                         GetReduceCountPerThreadFunctor>;
+                const auto kern_welford_second_half_batchnorm_forward_final =
+                    kernel_welford_second_half_batchnorm_forward_final<
+                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
+                        XDataType,
+                        YDataType,
+                        AccDataType,
+                        ScaleDataType,
+                        BiasDataType,
+                        MeanVarDataType,
+                        YElementwiseOp,
+                        XYGridDesc_M_K,
+                        MeanVarCountGridDesc_M_K,
+                        ScaleBiasMeanVarGridDesc_M,
+                        ScaleBiasMeanVarGridDesc_M>;
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_multiblock_welford_first_half,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_g,
+                    get_reduce_count_per_thread,
+                    arg.numBlockTileIteration_,
+                    arg.p_x_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_));
+                avg_time += launch_and_time_kernel(
+                    stream_config,
+                    kern_welford_second_half_batchnorm_forward_final,
+                    dim3(arg.gridSize_),
+                    dim3(BlockSize),
+                    0,
+                    arg.x_grid_desc_m_k_,
+                    arg.y_grid_desc_m_k_,
+                    mean_var_count_grid_desc_m_k,
+                    arg.scale_grid_desc_m_,
+                    arg.bias_grid_desc_m_,
+                    arg.mean_var_grid_desc_m_,
+                    arg.blkGroupSize_,
+                    arg.numBlockTileIteration_,
+                    arg.epsilon_,
+                    static_cast<MeanVarDataType*>(arg.workspace_mean_),
+                    static_cast<MeanVarDataType*>(arg.workspace_variance_),
+                    static_cast<int32_t*>(arg.workspace_count_),
+                    arg.p_x_,
+                    arg.p_scale_,
+                    arg.p_bias_,
+                    arg.y_elementwise_op_,
+                    arg.p_y_,
+                    arg.updateMovingAverage_,
+                    arg.averageFactor_,
+                    arg.resultRunningMean_,
+                    arg.resultRunningVariance_,
+                    arg.saveMeanInvVariance_,
+                    arg.resultSaveMean_,
+                    arg.resultSaveInvVariance_);
+            };
         }
         else
         {
...
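With the control_ pointer added above, the workspace now holds four consecutive regions: welford mean, welford variance, welford count, and the barrier objects used by the single-kernel path to synchronize workgroups. A minimal host-side sketch of the sizing arithmetic; the mean/variance region sizes are an assumption (only the count and barrier regions are fully visible in this hunk):

#include <cstddef>
#include <cstdint>

// Hedged sketch of the workspace bookkeeping above. Each welford region is
// sized as slots * element_size + 64, mirroring the 64-byte padding the code
// adds per region before aligning the next region's base pointer.
std::size_t batchnorm_fwd_workspace_size(std::size_t invariant_length,
                                         std::size_t blk_group_size,
                                         std::size_t m_block_tile_size,
                                         std::size_t sizeof_mean_var_type)
{
    const std::size_t slots = invariant_length * blk_group_size;
    std::size_t workspace_size = 0;
    workspace_size += slots * sizeof_mean_var_type + 64; // intermediate welford mean
    workspace_size += slots * sizeof_mean_var_type + 64; // intermediate welford variance
    workspace_size += slots * sizeof(std::int32_t) + 64; // intermediate welford count
    // one barrier object (two ints) per M block tile, zeroed via hipMemset
    workspace_size +=
        (invariant_length + m_block_tile_size - 1) / m_block_tile_size * sizeof(int) * 2;
    return workspace_size;
}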
@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
               p_ds_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e)},
               num_group_{a_g_n_k_wos_lengths[0]},
-              num_gemm_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;
-            // number of GEMM
-            num_gemm_ = YTilde * XTilde;
             for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
             {
                 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         void Print() const
         {
-            for(index_t i = 0; i < num_gemm_; i++)
+            for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
             {
                 std::cout << "a_grid_desc_ak0_m_ak1_container_"
                           << a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         // tensor descriptor for problem definition
         index_t num_group_;
-        index_t num_gemm_;
         std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
         std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
         std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         float ave_time = 0;
-        for(index_t i = 0; i < arg.num_gemm_; i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
                                             arg.b_grid_desc_n_k_container_[i],
@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
         // vector load for A matrix from global memory to LDS
-        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
+        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
+                     is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
         {
             if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
             {
@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
         // vector store for E
-        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
+        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
+                     is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
        {
             // vector store C matrix into global memory
             if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
...
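For context, the container sizes that replace num_gemm_ still equal YTilde * XTilde: a backward-data convolution with non-unit stride decomposes into that many independent GEMMs, with YTilde = ConvStrideH / gcd(ConvStrideH, ConvDilationH) exactly as computed above. A small sketch of that count (the function name is illustrative, not from CK):

#include <numeric> // std::gcd

// Number of sub-GEMMs a conv bwd-data problem decomposes into, mirroring the
// YTilde/XTilde computation above; this is what the removed num_gemm_ cached.
int num_bwd_data_gemms(int stride_h, int stride_w, int dilation_h, int dilation_w)
{
    const int y_tilde = stride_h / std::gcd(stride_h, dilation_h);
    const int x_tilde = stride_w / std::gcd(stride_w, dilation_w);
    return y_tilde * x_tilde;
}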
@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     } // function end
     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         Argument(const InDataType* p_in_grid,
                  WeiDataType* p_wei_grid,
                  const OutDataType* p_out_grid,
-                 ck::index_t G,
-                 ck::index_t N,
-                 ck::index_t K,
-                 ck::index_t C,
-                 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                 std::array<ck::index_t, NDimSpatial> input_left_pads,
-                 std::array<ck::index_t, NDimSpatial> input_right_pads,
+                 const ck::index_t G,
+                 const ck::index_t N,
+                 const ck::index_t K,
+                 const ck::index_t C,
+                 const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                 const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
+                 const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op,
@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         InElementwiseOperation c_element_op_;
         // for checking IsSupportedArgument()
-        index_t Conv_G_;
-        index_t Conv_N_;
-        index_t Conv_K_;
-        index_t Conv_C_;
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
-        std::array<ck::index_t, NDimSpatial> input_left_pads_;
-        std::array<ck::index_t, NDimSpatial> input_right_pads_;
+        const index_t Conv_G_;
+        const index_t Conv_N_;
+        const index_t Conv_K_;
+        const index_t Conv_C_;
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
         index_t k_batch_;
     };
@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     static auto MakeArgument(const InDataType* p_in_grid,
                              WeiDataType* p_wei_grid,
                              const OutDataType* p_out_grid,
-                             ck::index_t G,
-                             ck::index_t N,
-                             ck::index_t K,
-                             ck::index_t C,
-                             std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                             std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                             std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                             std::array<ck::index_t, NDimSpatial> input_left_pads,
-                             std::array<ck::index_t, NDimSpatial> input_right_pads,
+                             const ck::index_t G,
+                             const ck::index_t N,
+                             const ck::index_t K,
+                             const ck::index_t C,
+                             const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                             const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                             const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                             const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                             const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                             const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                             const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,
@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     MakeArgumentPointer(const void* p_in_grid,
                         void* p_wei_grid,
                         const void* p_out_grid,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...