Unverified Commit 63af525c authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

BatchNorm backward instance/external API/profiler/tests (#519)

* Refine the device batchnorm-backward base API templates and data type assignments

* Remove duplicated kernel file

* Add batchnorm backward instances and external API

* Add batchnorm-backward profiler and tests

* Add client example which uses batchnorm backward external API

* Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory

* Loosen the threshold for batchnorm-backward check_err()
parent 236bd148
# Client examples for batchnorm forward and backward. Each executable is built from a
# single source file and links privately against composable_kernel::device_operations.
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
// Element types used to instantiate the batchnorm-backward device operation in main().
// Each alias also sizes the matching device buffer via sizeof(), except AccDataType,
// which is only a template argument.
// Element type of the x buffer.
using XDataType = ck::half_t;
// Element type of the dx buffer.
using DxDataType = float;
// Element type of the dy buffer.
using DyDataType = float;
// Passed as the AccDataType template argument; presumably the accumulation type
// used inside the kernels (not visible in this file).
using AccDataType = float;
// Element type of the scale buffer.
using ScaleDataType = ck::half_t;
// Element type shared by the dscale and dbias buffers.
using DscaleDbiasDataType = float;
// Element type shared by the mean and invVariance buffers.
using MeanVarDataType = float;
// Element-wise operation type handed to the device operation for dy.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Number of dimensions of the x/dy/dx length and stride arrays.
constexpr int Rank = 4;
// Number of entries in reduceDims; the remaining Rank - NumBatchNormReduceDim
// dimensions size the scale/bias/mean/variance arrays.
constexpr int NumBatchNormReduceDim = 3;
// Value passed as the epsilon argument of MakeArgumentPointer(): the machine
// epsilon of float, stored as a double.
const double epsilon = std::numeric_limits<float>::epsilon();
// Minimal owner of one raw HIP device allocation used by this example.
//
// The buffer is requested with hipMalloc() in the constructor and released with
// hipFree() in the destructor. Copying is disabled: a copy would hold the same
// pointer and release it a second time when both objects are destroyed.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    // Requests mem_size bytes of device memory.
    // The HIP status is intentionally ignored (best effort, matching the rest of this
    // example); p_mem_ is value-initialized to nullptr before the call.
    explicit SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(&p_mem_, mem_size);
    }

    // Non-copyable: exactly one object may release the allocation.
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    // Returns the raw device pointer (not owned by the caller).
    void* GetDeviceBuffer() { return p_mem_; }

    // Releases the allocation; the HIP status is ignored.
    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
ck::index_t numXYElement =
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
1,
std::multiplies<ck::index_t>());
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
SimpleDeviceMem dy(sizeof(DyDataType) * numXYElement);
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem dx(sizeof(DxDataType) * numXYElement);
SimpleDeviceMem dscale(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem dbias(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
PassThrough,
Rank,
NumBatchNormReduceDim>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
dy.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
epsilon,
PassThrough{},
dx.GetDeviceBuffer(),
dscale.GetDeviceBuffer(),
dbias.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_bytes =
numXYElement * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
numScaleBiasMeanVarElement *
(sizeof(ScaleDataType) + sizeof(DscaleDbiasDataType) * 2 +
sizeof(MeanVarDataType) * 2);
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
if(found)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best intance
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
dy.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
epsilon,
PassThrough{},
dx.GetDeviceBuffer(),
dscale.GetDeviceBuffer(),
dbias.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
......@@ -11,7 +11,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
......@@ -106,7 +106,7 @@ class BatchNormBwdArg
using namespace ck;
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_bwd_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
......@@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
constexpr index_t Rank = 4;
constexpr index_t NumReduceDim = 3;
using ScaleDataType = XDataType;
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm backward algorithm
Tensor<InOutDataType> x(inOutLengths);
Tensor<InOutDataType> dy(inOutLengths);
Tensor<XDataType> x(inOutLengths);
Tensor<AccDataType> dy(inOutLengths);
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
......@@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);
// output data of the batchnorm backward algorithm
Tensor<InOutDataType> dx_ref(inOutLengths);
Tensor<InOutDataType> dx(inOutLengths);
Tensor<AccDataType> dx_ref(inOutLengths);
Tensor<AccDataType> dx(inOutLengths);
Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
......@@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
const float noise_stddev = 0.0001f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize the savedMean to be values with tiny variation to the mean of the x values
savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
......@@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
const float x_stddev = 1.0f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
};
if(do_verification)
......@@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
break;
case 2:
dy.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.2f, 0.2f}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.5f, 0.5f}, num_thread);
dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
}
};
// input data of the batchnorm backward algorithm
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());
// output data of the batchnorm backward algorithm
DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
......@@ -249,13 +251,13 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormBwdInstance =
ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
InOutDataType,
InOutDataType,
ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType, // ScaleDataType
AccDataType, // BiasDataType
AccDataType, // MeanVarDataType
ScaleDataType, // ScaleDataType
AccDataType, // DscaleDbiasDataType
AccDataType, // MeanVarDataType
PassThroughOp,
Rank,
NumReduceDim,
......@@ -269,8 +271,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
1, // XSrcVectorSize
1, // DySrcVectorSize
1, // DxDstVectorSize
1, // ScaleSrcDstVectorSize
1, // BiasDstVectorSize
1, // ScaleSrcVectorSize
1, // DscaleDbiasDstVectorSize
1>; // MeanVarSrcVectorSize
auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
......@@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
// inputing of x, dy, scale, outputing of dx, dscale, dbias
num_bytes +=
total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
// outputing of mean, inv-variance
num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
......@@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
if(do_verification)
{
using ReferenceBatchNormBwdInstance =
ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
AccDataType,
AccDataType,
AccDataType,
ScaleDataType, // ScaleDataType
AccDataType,
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
......@@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
dbias_dev.FromDevice(dbias.data());
// clang-format off
pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
// clang-format on
};
......
......@@ -13,7 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
......@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
......@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
// Owning pointer (std::unique_ptr) to a DeviceBatchNormBwd interface instantiated with
// the given data types, dy element-wise operation, tensor rank and number of
// batchnorm reduce dimensions. The template parameters are forwarded unchanged and in
// the same order to DeviceBatchNormBwd.
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -27,7 +27,7 @@ template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
......@@ -42,11 +42,19 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
......@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
const DyElementwiseOp dy_elementwise_op,
double epsilon,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
DscaleDbiasDataType* p_dscale,
DscaleDbiasDataType* p_dbias)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
......@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
scale_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
bias_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
dscale_dbias_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
mean_var_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
}
......@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
......@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
DscaleDbiasDataType* p_dscale_;
DscaleDbiasDataType* p_dbias_;
long_index_t invariant_length;
long_index_t reduce_length;
......@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
XYGridDesc_M_K dy_grid_desc_m_k;
XYGridDesc_M_K dx_grid_desc_m_k;
ScaleBiasGridDesc_M scale_grid_desc_m;
ScaleBiasGridDesc_M bias_grid_desc_m;
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
MeanVarGridDesc_M mean_var_grid_desc_m;
void* workspace_mean;
......@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
{
// workspace for the partial reduced result for dscale
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
// workspace for the partial reduced result for dbias
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
if(!pArg_->haveSavedMeanInvVar_)
{
......@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
// setup buffer for the partial reduced result for dscale
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for the partial reduced result for dbias
......@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
{
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford intermediate mean
......@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
ScaleSrcVectorSize,
DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>;
if(UseMultiblockInK && arg.blkGroupSize > 1)
......@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
DxDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_,
arg.p_dy_,
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
avg_time += launch_and_time_kernel(
stream_config,
......@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.dscale_dbias_grid_desc_m,
arg.blkGroupSize,
arg.reduce_length,
arg.numBlockTileIteration,
numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_
? arg.p_savedMean_
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
......@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
ScaleSrcVectorSize,
DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>;
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
......@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.dscale_dbias_grid_desc_m,
arg.mean_var_grid_desc_m,
get_reduce_count_per_thread,
arg.reduce_length,
......@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
return false;
if(pArg_->haveSavedMeanInvVar_)
......@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
......@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnDscaleDbiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
......@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
dy_elementwise_op,
epsilon,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<DscaleDbiasDataType*>(p_dbias));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
// clang-format on
return str.str();
......
......@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename DyDataType,
typename DxDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
......@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
......@@ -76,7 +76,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -92,8 +92,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
......@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
......@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
auto threadwise_dscale_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
......@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dscale_store_m =
auto threadwise_dscale_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
......@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
......@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
......@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
threadwise_dscale_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
// clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
......@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
......
......@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
......@@ -76,7 +76,7 @@ template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
......@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
};
};
};
......
......@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
......@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
reduce_size,
......@@ -77,7 +77,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -93,8 +93,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
......@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
......@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
using ck::math::sqrt;
......@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
......@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
InMemoryDataOperationEnum::Set,
1,
true>(
dy_grid_desc_m_k,
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
......@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
......@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
......@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
};
// clang-format off
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Device kernel entry point for the first half of a multi-block Welford reduction.
//
// The kernel does no work of its own: it forwards every argument unchanged to the static
// Run() of the gridwise implementation type passed as the first template parameter.
//
// x_grid_desc_m_k              - descriptor of the 2D (M, K) view of the input x
// mean_var_count_grid_desc_m_g - descriptor of the 2D (M, G) view shared by the three outputs
// get_reduce_count_per_thread  - functor forwarded to Run()
// num_k_block_tile_iteration   - number of K tiles forwarded to Run()
// p_x                          - input data
// p_welford_mean / p_welford_variance / p_welford_count - output buffers
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
// Gridwise implementation of the first half of a multi-block Welford reduction.
//
// Each workgroup reads a slice of the (M, K) input x, runs a thread-level Welford update
// followed by a block-level Welford step, and writes one partial (mean, variance, count)
// per M row into column `block_local_id` of the (M, G) outputs.
// NOTE(review): the partial results are presumably merged along G by a separate
// "second half" kernel -- that kernel is not part of this struct.
//
// Template parameters:
//   XDataType / AccDataType / MeanVarDataType - input, accumulation and output element types
//   XGridDesc_M_K, MeanVarCountGridDesc_M_G   - tensor descriptor types of input and outputs
//   GetReduceCountPerThreadFunctor            - callable invoked as f(block_local_id, thread_k_cluster_id)
//   BlockSize                                 - passed on to BlockwiseWelford
//   MThreadClusterSize, KThreadClusterSize    - thread cluster lengths along M and K
//   MThreadSliceSize, KThreadSliceSize        - per-thread slice lengths along M and K
//   XSrcCountSrcVectorDim, XSrcCountSrcVectorSize - vector dimension (0 = M, 1 = K) and vector
//                                               size used by the x load
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
// The per-thread slice length along the vector dimension must be a multiple of the vector size
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
// When the vector dimension is M (dim 0), both the thread-cluster arrangement order and the
// thread-buffer access order are switched from <0, 1> to <1, 0>
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
// Used in Run() to turn the 1D thread id into (thread_m_cluster_id, thread_k_cluster_id)
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// Packed descriptors of the per-thread (M, K) source slice and (M) destination slice
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
// NOTE(review): the meaning of the trailing `false` template argument is defined in
// blockwise_welford.hpp and cannot be determined from this file -- confirm there
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// Tile extents covered by one thread cluster along M and along K
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// Computes this workgroup's partial mean / variance / count and stores them.
//
// x_grid_desc_m_k              - descriptor of the (M, K) view of x
// mean_var_count_grid_desc_m_g - descriptor of the (M, G) outputs; its length along
//                                dimension 1 is used as the number of blocks per block-group
// get_reduce_count_per_thread  - result is assigned to ThreadwiseWelford::max_count_
// num_k_block_tile_iteration   - number of K_BlockTileSize-wide tiles each block walks over
// p_x                          - input data
// p_welford_mean / p_welford_variance / p_welford_count - outputs, written only by threads
//                                with thread_k_cluster_id == 0
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
// Per-thread buffers: the loaded x slice and the running Welford state for each M element
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
// Blocks are grouped: blkgroup_id selects the M tile, block_local_id selects the K range
// and the output column along G
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Extent along K handled by one block; used as the K start offset of each block below
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
// Loader for x: the window starts at this thread's M rows and at the thread's K offset
// inside the K range owned by this block
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
// Store object shared by the mean and the variance outputs (AccDataType -> MeanVarDataType);
// it targets column block_local_id of the (M, G) output
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
// Same destination window, but for the int32_t sample counts
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
// Each loop iteration advances the x window by one block tile along K
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
// NOTE(review): max_count_ presumably caps the number of samples a thread accumulates so
// that elements beyond the real reduce length are not counted -- confirm in threadwise_welford.hpp
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
// Start every running mean / variance at zero
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
// Thread-level pass: load one slice, move the window, update the running Welford state
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
// Block-level pass, one M element at a time. Every element starts from the same
// per-thread count (cur_count_).
// NOTE(review): the block_sync_lds() before every element except the first presumably
// protects work memory that BlockwiseWelford::Run reuses between calls -- confirm
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
// Only the threads with K cluster id 0 write the results for their M rows
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host (CPU) reference implementation of batch-normalization backward.
///
/// For every index over the invariant (non-reduced) dimensions it computes
///   dbias  = sum(dy)
///   dscale = sum(dy * (x - mean) * inv-variance)
///   dx     = 1/reduceSize * inv-variance * scale *
///            (reduceSize * dy - dbias - dscale * (x - mean) * inv-variance)
/// where the sums run over the reduced dimensions and dy is first passed through
/// DyElementwiseOp. When both savedMean and savedInvVar are supplied they are used directly;
/// otherwise mean and inv-variance = 1/sqrt(epsilon + variance) are recomputed from x with the
/// Welford method.
///
/// The invariant indices are distributed over host threads; all arithmetic is carried out in
/// AccDataType.
template <typename XDataType,
          typename DxDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim>
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
                                                                 DxDataType,
                                                                 DyDataType,
                                                                 AccDataType,
                                                                 ScaleDataType,
                                                                 DscaleDbiasDataType,
                                                                 MeanVarDataType,
                                                                 DyElementwiseOp,
                                                                 Rank,
                                                                 NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    /// Validated problem description: splits lengths/strides into their invariant and reduced
    /// parts and pre-computes the index sets iterated by the invoker.
    struct Argument : public device::BaseArgument
    {
        /// @throws std::runtime_error when reduceDims holds an out-of-range or duplicated
        ///         dimension, or when bnScaleBiasMeanVarLengths does not match the invariant
        ///         lengths derived from xyLengths.
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> dxStrides,
                 const std::array<index_t, Rank> dyStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
                 const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const DyDataType* p_dy,
                 const ScaleDataType* p_scale,
                 const MeanVarDataType* p_savedMean,
                 const MeanVarDataType* p_savedInvVar,
                 double epsilon,
                 const DyElementwiseOp dy_elementwise_op,
                 DxDataType* p_dx,
                 DscaleDbiasDataType* p_dscale,
                 DscaleDbiasDataType* p_dbias)
            : reduceDims_(reduceDims),
              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_dy_(p_dy),
              p_scale_(p_scale),
              p_savedMean_(p_savedMean),
              p_savedInvVar_(p_savedInvVar),
              dy_elementwise_op_(dy_elementwise_op),
              p_dx_(p_dx),
              p_dscale_(p_dscale),
              p_dbias_(p_dbias)
        {
            using ck::host_common::get_index_set;

            if(std::any_of(
                   reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
                throw std::runtime_error("Invalid reduce dimensions!");

            // get invariant_dims[] and invariant_lengths[]
            int numInvariantFound = 0;

            for(int dim = 0; dim < Rank; dim++)
            {
                const bool isReduced = std::any_of(
                    reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; });

                if(isReduced)
                    continue;

                // A duplicated entry in reduceDims leaves more than NumInvariantDim dimensions
                // unreduced; reject it instead of writing past the end of the arrays
                if(numInvariantFound >= NumInvariantDim)
                    throw std::runtime_error("Invalid reduce dimensions!");

                invariantDims_[numInvariantFound]     = dim;
                invariant_lengths_[numInvariantFound] = xyLengths[dim];
                numInvariantFound++;
            };

            // get reduce_lengths_[]
            for(int j = 0; j < NumBatchNormReduceDim; j++)
                reduce_lengths_[j] = xyLengths[reduceDims[j]];

            for(int i = 0; i < NumInvariantDim; i++)
                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
                    throw std::runtime_error("Invalid lengths parameters!");

            // strides of x/dy/dx along the invariant dimensions
            for(int i = 0; i < NumInvariantDim; i++)
            {
                const int dim = invariantDims_[i];

                x_invariant_strides_[i]  = xStrides[dim];
                dy_invariant_strides_[i] = dyStrides[dim];
                dx_invariant_strides_[i] = dxStrides[dim];
            };

            // strides of x/dy/dx along the reduced dimensions
            for(int i = 0; i < NumBatchNormReduceDim; i++)
            {
                const int dim = reduceDims_[i];

                x_reduce_strides_[i]  = xStrides[dim];
                dy_reduce_strides_[i] = dyStrides[dim];
                dx_reduce_strides_[i] = dxStrides[dim];
            };

            // The initial value fixes the accumulator type of std::accumulate; a plain `1`
            // would accumulate in int and truncate large reduce sizes
            reduceSize_ = std::accumulate(reduce_lengths_.begin(),
                                          reduce_lengths_.end(),
                                          static_cast<size_t>(1),
                                          std::multiplies<size_t>{});

            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
            reduce_index_set_    = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);

            // saved statistics are only used when both of them are available
            haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
        }

        std::array<int, NumBatchNormReduceDim> reduceDims_;
        std::array<int, NumInvariantDim> invariantDims_;
        std::array<index_t, NumInvariantDim> invariant_lengths_;
        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;

        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
        const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;

        std::array<index_t, NumInvariantDim> x_invariant_strides_;
        std::array<index_t, NumInvariantDim> dy_invariant_strides_;
        std::array<index_t, NumInvariantDim> dx_invariant_strides_;
        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;

        const XDataType* p_x_;
        const DyDataType* p_dy_;
        const ScaleDataType* p_scale_;
        const MeanVarDataType* p_savedMean_;
        const MeanVarDataType* p_savedInvVar_;
        const DyElementwiseOp dy_elementwise_op_;

        DxDataType* p_dx_;
        DscaleDbiasDataType* p_dscale_;
        DscaleDbiasDataType* p_dbias_;

        bool haveSavedMeanInvVar_;

        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;

        AccDataType epsilon_;
        size_t reduceSize_;
    };

    struct Invoker : public device::BaseInvoker
    {
        /// Runs the reference computation on the host.
        /// @return always 0.0f -- no execution time is measured.
        float Run(const Argument& arg)
        {
            using ck::host_common::get_offset_from_index;

            // Complete backward computation for one index over the invariant dimensions
            auto thread_reduce_func = [&](auto invariant_index) {
                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.x_invariant_strides_, invariant_index);
                size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.dy_invariant_strides_, invariant_index);
                size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.dx_invariant_strides_, invariant_index);

                AccDataType mean     = type_convert<AccDataType>(0.0f);
                AccDataType variance = type_convert<AccDataType>(0.0f);
                AccDataType invVar;
                int32_t curr_count = 0;

                if(arg.haveSavedMeanInvVar_)
                {
                    size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
                        arg.bnMeanVarStrides_, invariant_index);

                    mean =
                        type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
                    invVar =
                        type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
                }
                else
                {
                    // compute mean, variance using welford method
                    for(const auto& reduce_index : arg.reduce_index_set_)
                    {
                        size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                            arg.x_reduce_strides_, reduce_index);

                        auto x_offset = x_invariant_offset + x_reduce_offset;

                        curr_count++;

                        AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                        AccDataType delta = x - mean;

                        mean += delta / curr_count;

                        AccDataType delta2 = x - mean;

                        variance += delta * delta2;
                    };

                    // actual variance
                    variance = variance / curr_count;

                    // inv-variance defined as 1/sqrt(epsilon+variance)
                    invVar =
                        type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
                };

                AccDataType dbias =
                    type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
                AccDataType dscale =
                    type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x

                // 1) calculate dy * (x - mean) * inv-variance
                // 2) calculate sum(dy) on reduced dimensions
                // 3) calculate sum(dy * norm_x) on reduced dimensions
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dy_reduce_strides_, reduce_index);

                    auto x_offset  = x_invariant_offset + x_reduce_offset;
                    auto dy_offset = dy_invariant_offset + dy_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVar;
                    AccDataType dy     = type_convert<AccDataType>(arg.p_dy_[dy_offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    dbias += dy;
                    dscale += norm_x * dy;
                };

                // dscale and dbias share the same strides
                size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnDscaleDbiasStrides_, invariant_index);
                size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnDscaleDbiasStrides_, invariant_index);

                arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
                arg.p_dbias_[dbias_offset]   = type_convert<DscaleDbiasDataType>(dbias);

                size_t scale_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);

                AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);

                AccDataType multiplier = type_convert<AccDataType>(1.0f) /
                                         type_convert<AccDataType>(arg.reduceSize_) * invVar *
                                         scale;

                // 1) calculate tmp = dscale * (x - mean) * inv-variance
                // 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
                // - tmp)
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dy_reduce_strides_, reduce_index);
                    size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dx_reduce_strides_, reduce_index);

                    auto x_offset  = x_invariant_offset + x_reduce_offset;
                    auto dy_offset = dy_invariant_offset + dy_reduce_offset;
                    auto dx_offset = dx_invariant_offset + dx_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVar;
                    AccDataType dy     = type_convert<AccDataType>(arg.p_dy_[dy_offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    AccDataType tmpVal = norm_x * dscale;

                    AccDataType dx = multiplier * (type_convert<AccDataType>(arg.reduceSize_) * dy -
                                                   dbias - tmpVal);

                    arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
                };
            };

            // hardware_concurrency() may return 0 when the value is not computable; fall back
            // to one worker so the division below stays well defined
            const std::size_t num_thread =
                std::max<std::size_t>(1, std::thread::hardware_concurrency());
            const std::size_t work_per_thread =
                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                const std::size_t i_begin = it * work_per_thread;
                const std::size_t i_end =
                    std::min(static_cast<size_t>((it + 1) * work_per_thread),
                             arg.invariant_index_set_.size());

                // Capture arg and the worker by reference: a by-copy default capture would
                // duplicate the whole Argument (including both index sets) for every thread.
                // This is safe because `threads` is destroyed before both referenced objects
                // and each joinable_thread joins on destruction (this function already relies
                // on that join to finish the work before the caller reads the outputs).
                auto f = [&arg, &thread_reduce_func, i_begin, i_end] {
                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };

                threads[it] = joinable_thread(f);
            }

            return (0.0f);
        };

        /// Type-erased entry point; the stream configuration is ignored on the host.
        /// @throws std::runtime_error when p_arg is not a ReferenceBatchNormBwd::Argument.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            const auto* p_typed_arg = dynamic_cast<const Argument*>(p_arg);

            // dynamic_cast yields nullptr for a null or foreign argument object
            if(p_typed_arg == nullptr)
                throw std::runtime_error("Invalid argument type for ReferenceBatchNormBwd!");

            return Run(*p_typed_arg);
        };
    };

    /// The host reference accepts every argument object.
    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        (void)p_arg;

        return (true);
    };

    /// Builds an Argument from type-erased pointers; the pointers are cast to the element types
    /// given by the template parameters of this struct.
    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> dxStrides,
                        const std::array<index_t, Rank> dyStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
                        const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* p_dy,
                        const void* p_scale,
                        const void* p_savedMean,
                        const void* p_savedInvVar,
                        double epsilon,
                        const DyElementwiseOp dy_elementwise_op,
                        void* p_dx,
                        void* p_dscale,
                        void* p_dbias) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          dxStrides,
                                          dyStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnDscaleDbiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const DyDataType*>(p_dy),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const MeanVarDataType*>(p_savedMean),
                                          static_cast<const MeanVarDataType*>(p_savedInvVar),
                                          epsilon,
                                          dy_elementwise_op,
                                          static_cast<DxDataType*>(p_dx),
                                          static_cast<DscaleDbiasDataType*>(p_dscale),
                                          static_cast<DscaleDbiasDataType*>(p_dbias));
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Backward" << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host-side (CPU) reference for batch-normalization backward on a rank-4
/// NHWC tensor, reducing over N, H and W and producing one dscale/dbias value
/// per channel C.
///
/// The strides handed to the Argument are ignored: every element offset is
/// computed as if x, dy and dx were densely packed in N-H-W-C order.
///
/// @tparam XDataType        element type of the input x
/// @tparam DyDataType       element type of the incoming gradient dy
/// @tparam DxDataType       element type of the produced gradient dx
/// @tparam AccDataType      type used for all intermediate arithmetic
/// @tparam ScaleDataType    element type of scale and of the dscale output
/// @tparam BiasDataType     element type of the dbias output
/// @tparam MeanVarDataType  element type of the saved mean / inverse variance
/// @tparam DyElementwiseOp  functor applied to each dy value as op(out, in)
template <typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
    : public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{
    /// Bundles the buffers and problem sizes consumed by Invoker::Run.
    struct Argument : public device::BaseArgument
    {
        /// @throws std::runtime_error if the per-channel length differs from
        ///         xyLengths[3], or if reduceDims is not {0, 1, 2}.
        Argument(const std::array<index_t, 4> xyLengths,
                 const std::array<index_t, 4> xStrides,
                 const std::array<index_t, 4> dyStrides,
                 const std::array<index_t, 4> dxStrides,
                 const std::array<int, 3> reduceDims,
                 const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
                 const std::array<ck::index_t, 1> bnScaleStrides,
                 const std::array<ck::index_t, 1> bnBiasStrides,
                 const std::array<ck::index_t, 1> bnMeanVarStrides,
                 const XDataType* p_x,
                 const DyDataType* p_dy,
                 const ScaleDataType* p_scale,
                 const MeanVarDataType* p_savedMean,
                 const MeanVarDataType* p_savedInvVar,
                 double epsilon,
                 const DyElementwiseOp dy_elementwise_op,
                 DxDataType* p_dx,
                 ScaleDataType* p_dscale,
                 BiasDataType* p_dbias)
            : p_x_(p_x),
              p_dy_(p_dy),
              p_scale_(p_scale),
              p_savedMean_(p_savedMean),
              p_savedInvVar_(p_savedInvVar),
              epsilon_(epsilon),
              dy_elementwise_op_(dy_elementwise_op),
              p_dx_(p_dx),
              p_dscale_(p_dscale),
              p_dbias_(p_dbias)
        {
            // Strides are accepted for interface compatibility only; offsets are
            // always derived from the lengths (packed layout).
            ignore = xStrides;
            ignore = dyStrides;
            ignore = dxStrides;
            ignore = bnScaleStrides;
            ignore = bnBiasStrides;
            ignore = bnMeanVarStrides;

            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
                throw std::runtime_error("Invalid tensor dimensions!");

            if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
                throw std::runtime_error("Invalid reduce dimensions!");

            n_ = xyLengths[0];
            h_ = xyLengths[1];
            w_ = xyLengths[2];
            c_ = xyLengths[3];

            // Saved statistics are used only when both of them are supplied.
            haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
        }

        const XDataType* p_x_;
        const DyDataType* p_dy_;
        const ScaleDataType* p_scale_;
        const MeanVarDataType* p_savedMean_;
        const MeanVarDataType* p_savedInvVar_;

        double epsilon_;
        const DyElementwiseOp dy_elementwise_op_;

        DxDataType* p_dx_;
        ScaleDataType* p_dscale_;
        BiasDataType* p_dbias_;

        bool haveSavedMeanInvVar_;

        index_t n_, h_, w_, c_;
    };

    /// Executes the reference computation on the host.
    struct Invoker : public device::BaseInvoker
    {
        /// Computes dscale, dbias and dx for every channel; channels are split
        /// into contiguous ranges that are processed by separate host threads.
        ///
        /// @return always 0.0f (no timing is measured).
        float Run(const Argument& arg)
        {
            // Invokes visit(offset) for every (n, h, w) position of channel iC,
            // with n outermost and w innermost, using packed NHWC offsets.
            auto for_each_nhw = [&arg](index_t iC, auto&& visit) {
                for(index_t iN = 0; iN < arg.n_; iN++)
                {
                    for(index_t iH = 0; iH < arg.h_; iH++)
                    {
                        for(index_t iW = 0; iW < arg.w_; iW++)
                        {
                            visit(((iN * arg.h_ + iH) * arg.w_ + iW) * arg.c_ + iC);
                        }
                    }
                }
            };

            // Full backward computation for a single channel.
            auto thread_reduce_func = [&](index_t iC) {
                const AccDataType reduceSize = type_convert<AccDataType>(arg.n_) *
                                               type_convert<AccDataType>(arg.h_) *
                                               type_convert<AccDataType>(arg.w_);

                AccDataType mean;
                AccDataType invVar;

                if(arg.haveSavedMeanInvVar_)
                {
                    mean   = arg.p_savedMean_[iC];
                    invVar = arg.p_savedInvVar_[iC];
                }
                else
                {
                    // No saved statistics: derive mean and inverse variance
                    // from x as E[x] and 1 / sqrt(epsilon + E[x^2] - E[x]^2).
                    AccDataType meansquare = type_convert<AccDataType>(0.0f);
                    mean                   = type_convert<AccDataType>(0.0f);

                    for_each_nhw(iC, [&](index_t offset) {
                        const AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);

                        mean += x;
                        meansquare += x * x;
                    });

                    mean       = mean / reduceSize;
                    meansquare = meansquare / reduceSize;

                    const AccDataType variance = meansquare - mean * mean;

                    invVar = type_convert<AccDataType>(1.0f) /
                             std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
                }

                AccDataType dbias  = type_convert<AccDataType>(0.0f); // sum over NHW of dy
                AccDataType dscale = type_convert<AccDataType>(0.0f); // sum over NHW of dy * norm_x

                // First pass: accumulate dbias and dscale.
                for_each_nhw(iC, [&](index_t offset) {
                    const AccDataType x      = type_convert<AccDataType>(arg.p_x_[offset]);
                    const AccDataType norm_x = (x - mean) * invVar;
                    AccDataType dy           = type_convert<AccDataType>(arg.p_dy_[offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    dbias += dy;
                    dscale += norm_x * dy;
                });

                arg.p_dscale_[iC] = type_convert<ScaleDataType>(dscale);
                arg.p_dbias_[iC]  = type_convert<BiasDataType>(dbias);

                const AccDataType scale = type_convert<AccDataType>(arg.p_scale_[iC]);
                const AccDataType multiplier =
                    type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;

                // Second pass:
                //   dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - dscale * norm_x)
                for_each_nhw(iC, [&](index_t offset) {
                    const AccDataType x      = type_convert<AccDataType>(arg.p_x_[offset]);
                    const AccDataType norm_x = (x - mean) * invVar;
                    AccDataType dy           = type_convert<AccDataType>(arg.p_dy_[offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    const AccDataType tmpVal = norm_x * dscale;
                    const AccDataType dx     = multiplier * (reduceSize * dy - dbias - tmpVal);

                    // Convert to the element type of the dx buffer itself; going
                    // through XDataType would round through the (possibly
                    // narrower) input type first.
                    arg.p_dx_[offset] = type_convert<DxDataType>(dx);
                });
            };

            const std::size_t num_channel = static_cast<std::size_t>(arg.c_);

            // hardware_concurrency() may return 0 when the value is not
            // computable; fall back to one thread so the division is defined.
            const std::size_t num_thread =
                std::max<std::size_t>(1, std::thread::hardware_concurrency());
            const std::size_t work_per_thread = (num_channel + num_thread - 1) / num_thread;

            // NOTE(review): completion before return relies on joinable_thread
            // joining when `threads` is destroyed - confirm against its definition.
            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                const std::size_t ic_begin = it * work_per_thread;
                const std::size_t ic_end   = std::min((it + 1) * work_per_thread, num_channel);

                auto f = [=] {
                    for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
                    {
                        thread_reduce_func(static_cast<index_t>(ic));
                    }
                };

                threads[it] = joinable_thread(f);
            }

            return (0.0f);
        }

        /// Type-erased entry point; the stream configuration is not used.
        ///
        /// @throws std::runtime_error if p_arg is not an Argument of this
        ///         operation (including a null pointer).
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            const auto* p_typed_arg = dynamic_cast<const Argument*>(p_arg);

            if(p_typed_arg == nullptr)
                throw std::runtime_error("Invalid argument type for reference batchnorm backward!");

            return Run(*p_typed_arg);
        }
    };

    /// Accepts every argument; no validation is performed here.
    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        (void)p_arg;

        return (true);
    }

    /// Builds an Argument from type-erased buffers, casting each void pointer
    /// to the element type named by the corresponding template parameter.
    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
                        const std::array<index_t, 4> xStrides,
                        const std::array<index_t, 4> dyStrides,
                        const std::array<index_t, 4> dxStrides,
                        const std::array<int, 3> reduceDims,
                        const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
                        const std::array<ck::index_t, 1> bnScaleStrides,
                        const std::array<ck::index_t, 1> bnBiasStrides,
                        const std::array<ck::index_t, 1> bnMeanVarStrides,
                        const void* p_x,
                        const void* p_dy,
                        const void* p_scale,
                        const void* p_savedMean,
                        const void* p_savedInvVar,
                        double epsilon,
                        const DyElementwiseOp dy_elementwise_op,
                        void* p_dx,
                        void* p_dscale,
                        void* p_dbias) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          dyStrides,
                                          dxStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const DyDataType*>(p_dy),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const MeanVarDataType*>(p_savedMean),
                                          static_cast<const MeanVarDataType*>(p_savedInvVar),
                                          epsilon,
                                          dy_elementwise_op,
                                          static_cast<DxDataType*>(p_dx),
                                          static_cast<ScaleDataType*>(p_dscale),
                                          static_cast<BiasDataType*>(p_dbias));
    }

    /// Creates a default-constructed Invoker.
    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    /// Returns the name of this reference operation followed by a newline.
    /// NOTE(review): the literal ends with an unmatched '<'; left unchanged in
    /// case the exact text is relied upon elsewhere.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Declarations of the per-precision instance-registration functions for
// batchnorm backward with Rank = 4 and NumReduceDim = 3 and a PassThrough dy
// operation. They are only declared in this header; each takes the vector of
// instance pointers by non-const reference. The DeviceBatchNormBwd template
// arguments are, in order: XDataType, DxDataType, DyDataType, AccDataType,
// ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank,
// NumReduceDim.

// FP16: x and scale are F16, every other data type is F32
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32: every data type is F32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16: x and scale are BF16, every other data type is F32
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64: every data type is F64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
/// Instance factory specialization for batchnorm backward.
///
/// GetInstances() returns the registered device instances for the requested
/// combination of data types. Instances exist only for Rank == 4,
/// NumReduceDim == 3 and a PassThrough dy operation, in the F16, F32, BF16 and
/// F64 type configurations; for anything else the returned vector is empty.
template <typename XDataType,
          typename DxDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
                                                     DxDataType,
                                                     DyDataType,
                                                     AccDataType,
                                                     ScaleDataType,
                                                     DscaleDbiasDataType,
                                                     MeanVarDataType,
                                                     DyElementwiseOp,
                                                     Rank,
                                                     NumReduceDim>>
{
    using DeviceOp = DeviceBatchNormBwd<XDataType,
                                        DxDataType,
                                        DyDataType,
                                        AccDataType,
                                        ScaleDataType,
                                        DscaleDbiasDataType,
                                        MeanVarDataType,
                                        DyElementwiseOp,
                                        Rank,
                                        NumReduceDim>;

    static auto GetInstances()
    {
        // The only shape / element-wise-op combination with registered instances.
        constexpr bool is_rank_4_3_passthrough =
            Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>;

        // dx, dy, accumulator, dscale/dbias and mean/variance all F32 (shared by
        // the F16, F32 and BF16 configurations).
        constexpr bool non_x_scale_types_are_f32 =
            is_same_v<DxDataType, F32> && is_same_v<DyDataType, F32> &&
            is_same_v<AccDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
            is_same_v<MeanVarDataType, F32>;

        // Same set of types, all F64.
        constexpr bool non_x_scale_types_are_f64 =
            is_same_v<DxDataType, F64> && is_same_v<DyDataType, F64> &&
            is_same_v<AccDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
            is_same_v<MeanVarDataType, F64>;

        constexpr bool is_f16_config = is_same_v<XDataType, F16> &&
                                       is_same_v<ScaleDataType, F16> && non_x_scale_types_are_f32;
        constexpr bool is_f32_config = is_same_v<XDataType, F32> &&
                                       is_same_v<ScaleDataType, F32> && non_x_scale_types_are_f32;
        constexpr bool is_bf16_config = is_same_v<XDataType, BF16> &&
                                        is_same_v<ScaleDataType, BF16> && non_x_scale_types_are_f32;
        constexpr bool is_f64_config = is_same_v<XDataType, F64> &&
                                       is_same_v<ScaleDataType, F64> && non_x_scale_types_are_f64;

        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        // The four configurations are mutually exclusive (they differ in
        // XDataType), so at most one registration call is instantiated.
        if constexpr(is_rank_4_3_passthrough)
        {
            if constexpr(is_f16_config)
            {
                add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
            }
            else if constexpr(is_f32_config)
            {
                add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
            }
            else if constexpr(is_bf16_config)
            {
                add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
            }
            else if constexpr(is_f64_config)
            {
                add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,4 +3,8 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_f64_instance.cpp
device_batchnorm_backward_f16_instance.cpp
device_batchnorm_backward_f32_instance.cpp
device_batchnorm_backward_bf16_instance.cpp
device_batchnorm_backward_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short aliases for the element types named in the instance lists below.
using BF16 = ck::bhalf_t;
using F32 = float;
// Element-wise functor supplied as DyElementwiseOp when the instances are registered.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Tuple of BF16 batchnorm-backward implementations with UseMultiBlockInK = false.
// x and scale are BF16; every other data type is F32. All entries use
// BlockSize = 256 and thread slice sizes of 2; they differ in the M x K thread
// cluster shape (128x2 down to 1x256), in XDyDxVectorDim (0 or 1) and in the
// vector sizes (1 or 2).
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_bf16_blockwise_instances =
std::tuple <
// Template argument order:
//   XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType,
//   DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize,
//   MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize,
//   XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize,
//   ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Tuple of BF16 batchnorm-backward implementations with UseMultiBlockInK = true.
// x and scale are BF16; every other data type is F32. All entries use
// BlockSize = 256 and thread slice sizes of 2; they differ in the M x K thread
// cluster shape (128x2 down to 1x256), in XDyDxVectorDim (0 or 1) and in the
// vector sizes (1 or 2).
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_bf16_multiblock_instances =
std::tuple <
// Template argument order (names follow the blockwise instance list in this
// file - TODO confirm against the DeviceBatchNormBwdImpl declaration):
//   XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType,
//   DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize,
//   MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize,
//   XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize,
//   ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
/// Registers the BF16 batchnorm-backward instances for Rank = 4,
/// NumReduceDim = 3 and a PassThrough dy operation into `instances`: first the
/// blockwise list, then the multiblock list.
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>& instances)
{
    using BlockwiseInstances =
        device_batchnorm_backward_bf16_blockwise_instances<4, 3, PassThrough>;
    using MultiblockInstances =
        device_batchnorm_backward_bf16_multiblock_instances<4, 3, PassThrough>;

    add_device_operation_instances(instances, BlockwiseInstances{});
    add_device_operation_instances(instances, MultiblockInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short aliases for the element types named in the instance list below.
using F16 = ck::half_t;
using F32 = float;
// Alias for the element-wise PassThrough functor.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Table of DeviceBatchNormBwdImpl instantiations with F16 x/scale and F32 for every other
// data type, all with UseMultiBlockInK = false.
// Rows come in groups of six per (MThreadClusterSize, KThreadClusterSize) pair, running from
// (128, 2) down to (1, 256); each pair multiplies to the BlockSize of 256. Within a group the
// rows vary XDyDxVectorDim (0 or 1) and the vector sizes (1 or 2), as labelled by the column
// header comment below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f16_blockwise_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Table of DeviceBatchNormBwdImpl instantiations with F16 x/scale and F32 for every other
// data type, all with UseMultiBlockInK = true.
// Rows come in groups of six per (MThreadClusterSize, KThreadClusterSize) pair, running from
// (128, 2) down to (1, 256); each pair multiplies to the BlockSize of 256. Within a group the
// rows vary XDyDxVectorDim (0 or 1) and the vector sizes (1 or 2), as labelled by the column
// header comment below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f16_multiblock_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Registers the F16 batchnorm-backward instances for Rank = 4 / NumReduceDim = 3 with the
// PassThrough dy elementwise operation.
//
// `instances` is the caller-owned container handed to add_device_operation_instances();
// the blockwise table is passed first and the multiblock table second, so any ordering
// produced by that helper keeps blockwise entries ahead of multiblock ones.
void add_device_batchnorm_backward_rank_4_3_f16_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&
        instances)
{
    using BlockwiseTable  = device_batchnorm_backward_f16_blockwise_instances<4, 3, PassThrough>;
    using MultiblockTable = device_batchnorm_backward_f16_multiblock_instances<4, 3, PassThrough>;

    add_device_operation_instances(instances, BlockwiseTable{});
    add_device_operation_instances(instances, MultiblockTable{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short element-type alias used by the instance tables in this file.
using F32 = float;
// Elementwise operation supplied as DyElementwiseOp when the tables are instantiated below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Table of DeviceBatchNormBwdImpl instantiations with F32 for every data type, all with
// UseMultiBlockInK = false.
// Rows come in groups of six per (MThreadClusterSize, KThreadClusterSize) pair, running from
// (128, 2) down to (1, 256); each pair multiplies to the BlockSize of 256. Within a group the
// rows vary XDyDxVectorDim (0 or 1) and the vector sizes (1 or 2), as labelled by the column
// header comment below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f32_blockwise_instances = std::tuple<
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Table of DeviceBatchNormBwdImpl instantiations with F32 for every data type, all with
// UseMultiBlockInK = true.
// Rows come in groups of six per (MThreadClusterSize, KThreadClusterSize) pair, running from
// (128, 2) down to (1, 256); each pair multiplies to the BlockSize of 256. Within a group the
// rows vary XDyDxVectorDim (0 or 1) and the vector sizes (1 or 2), as labelled by the column
// header comment below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f32_multiblock_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Registers the F32 batchnorm-backward instances for Rank = 4 / NumReduceDim = 3 with the
// PassThrough dy elementwise operation.
//
// `instances` is the caller-owned container handed to add_device_operation_instances();
// the blockwise table is passed first and the multiblock table second, so any ordering
// produced by that helper keeps blockwise entries ahead of multiblock ones.
void add_device_batchnorm_backward_rank_4_3_f32_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&
        instances)
{
    using BlockwiseTable  = device_batchnorm_backward_f32_blockwise_instances<4, 3, PassThrough>;
    using MultiblockTable = device_batchnorm_backward_f32_multiblock_instances<4, 3, PassThrough>;

    add_device_operation_instances(instances, BlockwiseTable{});
    add_device_operation_instances(instances, MultiblockTable{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short element-type alias used for every data type in the instance table below.
using F64 = double;
// Elementwise operation type taken from ck::tensor_operation::element_wise.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Tuple of FP64 batchnorm-backward instances with UseMultiBlockInK == false.
// Every row uses BlockSize 256 and 2x2 thread slices; rows are grouped six at a time per
// MThreadClusterSize x KThreadClusterSize split (128x2, 64x4, ..., 1x256), and within each
// group the vector dimension (0 or 1) and the vector sizes (1 or 2) are varied.
// The meaning of each template argument is given by the column header below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f64_blockwise_instances = std::tuple<
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Tuple of FP64 batchnorm-backward instances with UseMultiBlockInK == true.
// The rows mirror the blockwise FP64 tuple one-for-one (same BlockSize, cluster splits,
// slice sizes, vector dimension and vector sizes); only the UseMultiBlockInK flag differs.
// The meaning of each template argument is given by the column header below.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f64_multiblock_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Appends every FP64 batchnorm-backward instance for rank-4 tensors with 3 reduce
// dimensions (blockwise tuple first, then the multiblock tuple) to `instances`.
void add_device_batchnorm_backward_rank_4_3_f64_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&
        instances)
{
    using BlockwiseInstances =
        device_batchnorm_backward_f64_blockwise_instances<4, 3, PassThrough>;
    using MultiblockInstances =
        device_batchnorm_backward_f64_multiblock_instances<4, 3, PassThrough>;

    // registration order is kept: blockwise instances precede multiblock ones
    add_device_operation_instances(instances, BlockwiseInstances{});
    add_device_operation_instances(instances, MultiblockInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -27,6 +27,7 @@ set(PROFILER_SOURCE
src/profile_layernorm.cpp
src/profile_softmax.cpp
src/profile_batchnorm_fwd.cpp
src/profile_batchnorm_bwd.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
namespace ck {
namespace profiler {
// Runs every registered device batchnorm-backward instance that matches the given data
// types / Rank / NumBatchNormReduceDim, optionally timing it and checking its outputs
// (dx, dscale, dbias) against the host reference implementation.
//
// Parameters:
//   do_verification     - run the host reference and compare each instance against it
//   init_method         - generator used for dy and bnScale (0, 1, 2, anything else = decimal);
//                         note that dy/bnScale are only generated when do_verification is set
//   do_dumpout          - write x, dy, dx and dscale buffers to dump_*.bin files
//   time_kernel         - forwarded to StreamConfig; also enables the perf print-outs
//   inOutLengths        - lengths of x/dy/dx, must have exactly Rank entries
//   reduceDims          - dimensions to reduce, must have exactly NumBatchNormReduceDim
//                         entries, each in [0, Rank)
//   haveSavedMeanInvVar - pass savedMean/savedInvVar buffers to the op instead of nullptr
//   epsilon             - forwarded to the reference and device argument builders
//
// Returns true when at least one instance supports the arguments and every verified
// instance passed; false when the reference rejects the arguments, no instance is
// applicable, or any comparison failed.
// Throws std::runtime_error on mismatched lengths/reduceDims sizes or an out-of-range
// reduce dimension.
template <typename XDataType,
          typename DxDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          index_t Rank,
          index_t NumBatchNormReduceDim>
bool profile_batchnorm_backward_impl(bool do_verification,
                                     int init_method,
                                     bool do_dumpout,
                                     bool time_kernel,
                                     const std::vector<size_t> inOutLengths,
                                     const std::vector<int> reduceDims,
                                     bool haveSavedMeanInvVar,
                                     double epsilon)
{
    if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
    {
        throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!");
    }

    std::vector<size_t> scaleBiasMeanVarLengths;

    // used for calculating the effective transferred bytes by each operation
    size_t invariant_length = 1;

    // the initial value must be a size_t: std::accumulate accumulates in the type of the
    // initial value, so an int literal would truncate/overflow for large tensors
    const size_t total_length = std::accumulate(
        inOutLengths.begin(), inOutLengths.end(), size_t{1}, std::multiplies<size_t>{});

    if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
        throw std::runtime_error("Invalid reduce dimensions!");

    // the non-reduced (invariant) dimensions define the shape of scale/bias/mean/variance
    for(int dim = 0; dim < Rank; dim++)
    {
        if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; }))
        {
            scaleBiasMeanVarLengths.push_back(inOutLengths[dim]);
            invariant_length *= inOutLengths[dim];
        }
    }

    // input data of the batchnorm backward algorithm
    Tensor<XDataType> x(inOutLengths);
    Tensor<DyDataType> dy(inOutLengths);
    Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);

    Tensor<MeanVarDataType> savedMean(scaleBiasMeanVarLengths);
    Tensor<MeanVarDataType> savedInvVar(scaleBiasMeanVarLengths);
    // savedVariance is only used for initializing savedInvVar
    Tensor<MeanVarDataType> savedVariance(scaleBiasMeanVarLengths);

    // output data of the batchnorm backward algorithm
    Tensor<DxDataType> dx_ref(inOutLengths);
    Tensor<DxDataType> dx(inOutLengths);

    Tensor<DscaleDbiasDataType> dscale(scaleBiasMeanVarLengths);
    Tensor<DscaleDbiasDataType> dbias(scaleBiasMeanVarLengths);

    Tensor<DscaleDbiasDataType> dscale_ref(scaleBiasMeanVarLengths);
    Tensor<DscaleDbiasDataType> dbias_ref(scaleBiasMeanVarLengths);

    auto inOutStrides            = x.mDesc.GetStrides();
    auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();

    std::size_t num_thread = std::thread::hardware_concurrency();

    const float x_mean   = 0.0f;
    const float x_stddev = 1.0f;

    // input data in normal distribution (generated first in both modes)
    x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);

    if(haveSavedMeanInvVar)
    {
        const float noise_stddev = 0.0001f;

        // initialize the savedMean to be values with tiny variation to the mean of the x values
        savedMean.GenerateTensorValue(GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev},
                                      num_thread);

        // initialize the variance to be values with tiny variation to the variance of the x values
        savedVariance.GenerateTensorValue(
            GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);

        const float tmp_epsilon = std::numeric_limits<float>::epsilon();

        // savedInvVar = 1 / sqrt(variance + eps), stored in the tensor's own element type
        std::transform(savedVariance.mData.begin(),
                       savedVariance.mData.end(),
                       savedInvVar.mData.begin(),
                       [tmp_epsilon](const MeanVarDataType& variance) {
                           return type_convert<MeanVarDataType>(
                               1.0f / std::sqrt(type_convert<float>(variance) + tmp_epsilon));
                       });
    }

    if(do_verification)
    {
        switch(init_method)
        {
        case 0:
            dy.GenerateTensorValue(GeneratorTensor_0<DyDataType>{}, num_thread);
            bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
            break;
        case 1:
            dy.GenerateTensorValue(GeneratorTensor_1<DyDataType>{1}, num_thread);
            bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
            break;
        case 2:
            dy.GenerateTensorValue(GeneratorTensor_2<DyDataType>{-2, 2}, num_thread);
            bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
            break;
        default:
            dy.GenerateTensorValue(GeneratorTensor_3<DyDataType>{-0.2f, 0.2f}, num_thread);
            bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
        }
    }

    // input data of the batchnorm backward algorithm
    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem dy_dev(sizeof(DyDataType) * dy.mDesc.GetElementSpaceSize());

    DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());

    DeviceMem savedMean_dev(sizeof(MeanVarDataType) * savedMean.mDesc.GetElementSpaceSize());
    DeviceMem savedInvVar_dev(sizeof(MeanVarDataType) * savedInvVar.mDesc.GetElementSpaceSize());

    // output data of the batchnorm backward algorithm
    DeviceMem dx_dev(sizeof(DxDataType) * dx.mDesc.GetElementSpaceSize());

    DeviceMem dscale_dev(sizeof(DscaleDbiasDataType) * dscale.mDesc.GetElementSpaceSize());
    DeviceMem dbias_dev(sizeof(DscaleDbiasDataType) * dbias.mDesc.GetElementSpaceSize());

    x_dev.ToDevice(x.mData.data());
    dy_dev.ToDevice(dy.mData.data());
    bnScale_dev.ToDevice(bnScale.mData.data());

    if(haveSavedMeanInvVar)
    {
        savedMean_dev.ToDevice(savedMean.mData.data());
        savedInvVar_dev.ToDevice(savedInvVar.mData.data());
    }

    // the device/reference APIs take fixed-size arrays, so copy the vectors over
    std::array<index_t, Rank> arrInOutLengths;
    std::array<index_t, Rank> arrInOutStrides;
    std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarLengths;
    std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;
    std::array<int, NumBatchNormReduceDim> arrReduceDims;

    std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin());
    std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin());
    std::copy(scaleBiasMeanVarLengths.begin(),
              scaleBiasMeanVarLengths.end(),
              arrScaleBiasMeanVarLengths.begin());
    std::copy(scaleBiasMeanVarStrides.begin(),
              scaleBiasMeanVarStrides.end(),
              arrScaleBiasMeanVarStrides.begin());

    std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin());

    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

    // add device batchnorm-backward instances
    // (template argument order: X, Dx, Dy, Acc, Scale, DscaleDbias, MeanVar, ...)
    using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
                                                                      DxDataType,
                                                                      DyDataType,
                                                                      AccDataType,
                                                                      ScaleDataType,
                                                                      DscaleDbiasDataType,
                                                                      MeanVarDataType,
                                                                      PassThroughOp,
                                                                      Rank,
                                                                      NumBatchNormReduceDim>;

    // get device op instances
    const auto instance_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

    std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;

    std::string best_instance_name;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    if(do_verification)
    {
        using ReferenceBatchNormBwdInstance =
            ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
                                                              DxDataType,
                                                              DyDataType,
                                                              AccDataType,
                                                              ScaleDataType,
                                                              DscaleDbiasDataType,
                                                              MeanVarDataType,
                                                              PassThroughOp,
                                                              Rank,
                                                              NumBatchNormReduceDim>;

        auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};

        auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer(
            arrInOutLengths,
            arrInOutStrides,
            arrInOutStrides,
            arrInOutStrides,
            arrReduceDims,
            arrScaleBiasMeanVarLengths,
            arrScaleBiasMeanVarStrides,
            arrScaleBiasMeanVarStrides,
            arrScaleBiasMeanVarStrides,
            x.mData.data(),
            dy.mData.data(),
            bnScale.mData.data(),
            haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
            haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
            epsilon,
            PassThroughOp{},
            dx_ref.mData.data(),
            dscale_ref.mData.data(),
            dbias_ref.mData.data());

        if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
        {
            std::cout << "The runtime parameters not supported by the reference instance, exiting!"
                      << std::endl;
            return (false);
        }

        auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer();

        (void)invoker_ptr_ref->Run(argument_ptr_ref.get());
    }

    // Effective bytes moved by one run; identical for every instance, so computed once.
    // inputting of x, dy, scale, outputting of dx, dscale, dbias
    size_t num_bytes =
        total_length * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
        invariant_length * sizeof(ScaleDataType) +
        invariant_length * sizeof(DscaleDbiasDataType) * 2;

    // inputting of savedMean, savedInvVariance
    if(haveSavedMeanInvVar)
        num_bytes += invariant_length * sizeof(MeanVarDataType) * 2;

    int num_kernel = 0;
    bool pass      = true;

    for(auto& inst_ptr : instance_ptrs)
    {
        auto argument_ptr = inst_ptr->MakeArgumentPointer(
            arrInOutLengths,
            arrInOutStrides,
            arrInOutStrides,
            arrInOutStrides,
            arrReduceDims,
            arrScaleBiasMeanVarLengths,
            arrScaleBiasMeanVarStrides,
            arrScaleBiasMeanVarStrides,
            arrScaleBiasMeanVarStrides,
            x_dev.GetDeviceBuffer(),
            dy_dev.GetDeviceBuffer(),
            bnScale_dev.GetDeviceBuffer(),
            haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
            haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
            epsilon,
            PassThroughOp{},
            dx_dev.GetDeviceBuffer(),
            dscale_dev.GetDeviceBuffer(),
            dbias_dev.GetDeviceBuffer());

        if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            if(time_kernel)
            {
                std::cout << inst_ptr->GetTypeString()
                          << " skipped due to unsupported argument: " << std::endl;
            }

            continue;
        }

        num_kernel++;

        size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());

        DeviceMem workspace_dev(workspace_sz);

        inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

        auto invoker_ptr = inst_ptr->MakeInvokerPointer();

        float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        float gb_per_sec = num_bytes / 1.E6 / avg_time;

        if(time_kernel)
            std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, "
                      << inst_ptr->GetTypeString() << std::endl;

        if(avg_time < best_avg_time)
        {
            best_instance_name = inst_ptr->GetTypeString();
            best_avg_time      = avg_time;
            best_gb_per_sec    = gb_per_sec;
        }

        if(do_verification)
        {
            using ck::utils::check_err;
            bool single_pass = true;

            dx_dev.FromDevice(dx.mData.data());
            dscale_dev.FromDevice(dscale.mData.data());
            dbias_dev.FromDevice(dbias.mData.data());

            // clang-format off
            single_pass = single_pass && check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
            single_pass = single_pass && check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
            single_pass = single_pass && check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
            // clang-format on

            pass = pass && single_pass;
        }

        if(do_dumpout)
        {
            using ck::host_common::dumpBufferToFile;

            // clang-format off
            dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize());
            dumpBufferToFile("dump_dy.bin", dy.mData.data(), dy.mDesc.GetElementSize());
            dumpBufferToFile("dump_dx.bin", dx.mData.data(), dx.mDesc.GetElementSize());
            dumpBufferToFile("dump_dx_ref.bin", dx_ref.mData.data(), dx_ref.mDesc.GetElementSize());
            dumpBufferToFile("dump_dscale.bin", dscale.mData.data(), dscale.mDesc.GetElementSize());
            dumpBufferToFile("dump_dscale_ref.bin", dscale_ref.mData.data(), dscale_ref.mDesc.GetElementSize());
            // clang-format on
        }
    }

    if(time_kernel)
    {
        std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
                  << best_instance_name << std::endl;
    }

    if(num_kernel == 0)
    {
        std::cout << "Error: No kernel is applicable" << std::endl;
        return false;
    }

    return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <getopt.h>
#include "ck/library/utility/host_common_util.hpp"
#include "profiler/include/profile_batchnorm_backward_impl.hpp"
using ck::index_t;
using namespace std;
// Long-option table handed to getopt_long(); each entry maps a long name to the
// character returned for it. The all-zero entry terminates the table.
// clang-format off
static const struct option long_options[] = {
    {"inOutLengths", required_argument, nullptr, 'D'},
    {"reduceDims",   required_argument, nullptr, 'R'},
    {"dumpout",      required_argument, nullptr, 'o'},
    {"verify",       required_argument, nullptr, 'v'},
    {"help",         no_argument,       nullptr, '?'},
    {nullptr,        0,                 nullptr, 0}
};
// clang-format on
// Command-line parser for the batchnorm-backward profiler.
// Options (-D/-R/-v/-o/--help) are parsed with getopt_long(); the four positional
// arguments that follow select the data type, the saved mean/invVariance mode, the
// init method and whether kernels are timed.
class BatchnormBwdArgParser
{
    private:
    // index into long_options written by getopt_long() when a long option is matched
    int option_index = 0;

    public:
    std::vector<size_t> inLengths;
    std::vector<int> reduceDims;

    bool do_verification = false;
    bool do_dumpout      = false;

    // initialized like every other member so the parser never exposes an indeterminate value
    bool haveSavedMeanInvVar = false;

    int data_type    = 0;
    int init_method  = 2;
    bool time_kernel = false;

    BatchnormBwdArgParser()  = default;
    ~BatchnormBwdArgParser() = default;

    // Prints the supported options and positional arguments to std::cout.
    void show_usage(const char* cmd)
    {
        // clang-format off
        std::cout << "Usage of " << cmd << std::endl;
        std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
        std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl;
        std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
        std::cout << "--dumpout or -o, 1/0 to indicate whether to dump the input and result tensors to dump_*.bin files" << std::endl;
        std::cout << "--help, show this usage message" << std::endl;
        std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl;
        std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl;
        std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
        std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl;
        // clang-format on
    }

    // Parses argv into the public members.
    // Returns 0 on success and -1 when --help was requested, an unexpected option code
    // was returned by getopt_long(), or the data type is not one of 0/1/5/6.
    // Throws std::runtime_error when an option has no value or fewer than four
    // positional arguments remain.
    int operator()(int argc, char* argv[])
    {
        using ck::host_common::getTypeValuesFromString;

        int ch;

        optind++; // to skip the module name

        while(1)
        {
            ch = getopt_long(argc, argv, "D:R:v:o:", long_options, &option_index);
            if(ch == -1)
                break;
            switch(ch)
            {
            case 'D':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                inLengths = getTypeValuesFromString<size_t>(optarg);
                break;
            case 'R':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                reduceDims = getTypeValuesFromString<int>(optarg);
                break;
            case 'v':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                do_verification = static_cast<bool>(std::atoi(optarg));
                break;
            case 'o':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                do_dumpout = static_cast<bool>(std::atoi(optarg));
                break;
            case '?':
                // '?' is both the value assigned to --help and what getopt_long() returns
                // for an unrecognized option; only the former shows the usage and stops
                if(std::string(long_options[option_index].name) == "help")
                {
                    show_usage(argv[0]);
                    return -1;
                }
                break;

            default:
                show_usage(argv[0]);
                std::cerr << "Invalid cmd-line options!" << std::endl;
                return -1;
            }
        }

        if(optind + 4 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");

        data_type           = std::atoi(argv[optind++]);
        haveSavedMeanInvVar = static_cast<bool>(std::atoi(argv[optind++]));
        init_method         = std::atoi(argv[optind++]);
        time_kernel         = static_cast<bool>(std::atoi(argv[optind++]));

        // only the data types listed in show_usage() are accepted
        if(data_type != 0 && data_type != 1 && data_type != 5 && data_type != 6)
            return -1;

        return 0;
    }
}; // end of class BatchnormBwdArgParser
static const double epsilon = std::numeric_limits<float>::epsilon();
int profile_batchnorm_backward(int argc, char* argv[])
{
using ck::profiler::profile_batchnorm_backward_impl;
BatchnormBwdArgParser arg_parser;
if(arg_parser(argc, argv) != 0)
return -1;
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
using F64 = double;
if(arg_parser.data_type == 0)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_backward_impl<F16, F32, F32, F32, F16, F32, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.haveSavedMeanInvVar,
epsilon);
};
}
else if(arg_parser.data_type == 1)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_backward_impl<F32, F32, F32, F32, F32, F32, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.haveSavedMeanInvVar,
epsilon);
};
}
else if(arg_parser.data_type == 5)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_backward_impl<BF16, F32, F32, F32, BF16, F32, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.haveSavedMeanInvVar,
epsilon);
};
}
else if(arg_parser.data_type == 6)
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_backward_impl<F64, F64, F64, F64, F64, F64, F64, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,
arg_parser.time_kernel,
arg_parser.inLengths,
arg_parser.reduceDims,
arg_parser.haveSavedMeanInvVar,
epsilon);
};
}
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment