Commit b7f500f0 authored by rocking5566, committed by rocking

Merge branch 'develop' into gemm_layernorm_welford

parents 694057a7 4e6a5575
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <vector>
#include <array>
#include <limits>
#include <string>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"
using XDataType = float;
using YDataType = float;
using AccDataType = float;
using ScaleDataType = AccDataType;
using BiasDataType = AccDataType;
using MeanVarDataType = AccDataType;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumBatchNormReduceDim = 3;
const double epsilon = std::numeric_limits<float>::epsilon();
const double averageFactor = 0.1;
// minimal RAII holder for device memory; hipMalloc/hipFree results are intentionally ignored
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
ck::index_t numXYElement =
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
1,
std::multiplies<ck::index_t>());
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
SimpleDeviceMem y(sizeof(YDataType) * numXYElement);
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
PassThrough,
Rank,
NumBatchNormReduceDim>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr, // resultRunningMean (running statistics not requested)
nullptr); // resultRunningVariance
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_bytes =
numXYElement * (sizeof(XDataType) + sizeof(YDataType)) +
numScaleBiasMeanVarElement * (sizeof(ScaleDataType) + sizeof(BiasDataType) +
sizeof(MeanVarDataType) + sizeof(MeanVarDataType));
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
if(found)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best instance
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr,
nullptr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
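For reference, the bandwidth figure printed by the profiling loop works out as follows for the 16x8x128x256 FP32 problem (a worked example, not part of the committed file):

// numXYElement = 16 * 8 * 128 * 256 = 4,194,304
// num_bytes    = 4,194,304 * (4 + 4)          // read x, write y
//              + 256 * (4 + 4 + 4 + 4)        // scale, bias, mean, invVariance
//              = 33,558,528 bytes (~33.6 MB)
// gb_per_sec   = num_bytes / 1.0e6 / ave_time_ms
// e.g. ave_time = 0.05 ms  ->  ~671 GB/s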
......@@ -15,7 +15,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of the highest
// dimension, e.g. N of NHWC, so lengths[3] is the length of dimension C in NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm
......@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
......@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
{
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
......@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
......
......@@ -15,7 +15,8 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
#include "batchnorm_infer_impl.hpp"
......@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of the highest
// dimension, e.g. N of NHWC, so lengths[3] is the length of dimension C in NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm
......@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification)
{
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
InOutDataType,
ck::tensor_operation::host::ReferenceBatchNormInfer<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
......@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
bnScale.mData.data(),
bnBias.mData.data(),
epsilon,
PassThroughOp{},
estimatedMean.mData.data(),
estimatedVariance.mData.data(),
y_ref.mData.data());
......
......@@ -13,7 +13,15 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
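With the data types now part of the interface, a batchnorm forward operation is fully identified by its template arguments. A minimal usage sketch (types match the client example above; the alias name is illustrative):

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// NHWC FP32 forward reducing over N/H/W: Rank = 4, NumBatchNormReduceDim = 3
using FwdF32Ptr = ck::tensor_operation::device::
DeviceBatchNormFwdPtr<float, float, float, float, float, float, PassThrough, 4, 3>;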
......@@ -13,13 +13,22 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
......@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const void* estimatedMean,
const void* estimatedInvVariance,
void* p_y) = 0;
......@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -13,15 +13,14 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
namespace ck {
template <typename GridwiseGemm,
template <typename GridwiseGemmWelford,
typename ABDataType,
typename DsPointer,
typename EDataType,
......@@ -63,9 +62,10 @@ __global__ void
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
......@@ -102,23 +102,23 @@ __global__ void
#endif
}
// template <typename GridwiseWelfordLayernorm,
// typename EDataType,
// typename HDataType,
// typename MeanDataType,
// typename VarDataType>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size)
// {
// GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// }
template <typename GridwiseWelfordLayernorm,
typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
HDataType* __restrict__ p_y_grid,
index_t blkgroup_size)
{
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
}
} // namespace ck
......@@ -335,8 +335,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
......@@ -388,29 +387,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
1,
LoopSched>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType,
// MeanDataType,
// VarDataType,
// AccDataType,
// HGridDesc_M_N,
// MeanGridDesc_M_N,
// GammaBetaGridDesc_N,
// MeanVarGridDesc_M,
// BlockSize,
// LayernormMThreadClusterSize,
// LayernormNThreadClusterSize,
// LayernormMThreadSliceSize,
// LayernormNThreadSliceSize,
// LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize,
// LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>;
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType,
HDataType,
MeanDataType,
VarDataType,
AccDataType,
HGridDesc_M_N,
MeanVarGridDesc_M_N,
GammaBetaGridDesc_N,
MeanVarGridDesc_M,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M_N::At(I0),
LayernormThreadSliceSize_M_N::At(I1),
LayernormESrcHDstVectorDim,
LayernormESrcVectorSize,
LayernormHDstVectorSize,
LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>;
// Argument
struct Argument : public BaseArgument
......@@ -451,7 +450,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -484,7 +483,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
});
// populate desc for Ds/E/F/G
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
......@@ -493,19 +492,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
mean_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_);
var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_);
}
......@@ -526,7 +525,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
MeanDataType* p_mean_grid_; // mean
VarDataType* p_var_grid_; // variance * count
......@@ -546,15 +545,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
......@@ -579,7 +578,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
float avg_time = 0;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......@@ -587,7 +586,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.var_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
}
const index_t grid_size =
......@@ -601,30 +600,32 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const auto kernel_gemm_welford =
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
GridwiseGemm,
GridwiseGemmWelford,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
typename GridwiseGemmWelford::DsGridPointer,
EDataType,
MeanDataType,
VarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemmWelford::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>;
// const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType,
// HDataType,
// MeanDataType,
// VarDataType>;
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
MeanDataType,
VarDataType>;
avg_time +=
launch_and_time_kernel(stream_config,
......@@ -649,21 +650,21 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
// avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm,
// dim3(grid_size),
// dim3(BlockSize),
// 0,
// arg.p_e_grid_,
// arg.p_mean_grid_,
// arg.p_var_grid_,
// arg.p_h_grid_,
// arg.blkGroupSize_);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.p_h_grid_,
arg.blkGroupSize_);
return avg_time;
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
......
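In summary, the invoker now runs the fused pipeline in two launches; schematically (a recap of the code above, not additional committed code):

// Phase 1: kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle
//   GEMM + per-blockgroup Welford partials -> E, mean, var (variance * count)
// Phase 2: kernel_welford_layernorm2d_second_half
//   merge partials across block groups + layernorm epilogue -> H
// avg_time is the sum of the two launch_and_time_kernel measurements.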
......@@ -42,8 +42,15 @@ template <typename XDataType,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Rank bigger than 6 is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......
......@@ -19,18 +19,115 @@
namespace ck {
template <typename XDataType, typename YDataType, typename MeanDataType, typename VarDataType>
template <typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename ComputeDataType,
typename XYGridDesc_M_N,
typename MeanVarGridDesc_M_N,
typename GammaBetaGridDesc_N,
typename MeanVarGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
__device__ static void Run(const XDataType* __restrict__ p_x_grid,
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<ComputeDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_N,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid)
HDataType* __restrict__ p_h_grid,
/*const MeanVarGridDesc_M_N& mean_grid_desc_m_k,
const MeanVarGridDesc_M_N& var_grid_desc_m_k,
const GammaBetaGridDesc_N& gamma_grid_desc_m,
const GammaBetaGridDesc_N& beta_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,*/
index_t blkgroup_size)
{
ignore = p_x_grid;
ignore = p_e_grid;
ignore = p_mean_grid;
ignore = p_var_grid;
ignore = p_y_grid;
ignore = p_h_grid;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
/*
auto threadwise_mean_load_m_n =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarGridDesc_M_N,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_grid_desc_m_n,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * 1));*/
} // run
};
......
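The Run body above is still a scaffold (all pointers are ignored); the intended math is a Welford merge of the per-blockgroup partials produced by the first-half kernel, followed by the layernorm epilogue. A scalar host-side sketch under those assumptions (the struct and function names are illustrative):

// Merge two Welford partials; m2 is "variance * count", matching p_var_grid above.
struct WelfordPartial { float mean; float m2; int count; };

WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
WelfordPartial r;
r.count = a.count + b.count;
const float delta = b.mean - a.mean;
r.mean = a.mean + delta * b.count / r.count;
r.m2 = a.m2 + b.m2 + delta * delta * (static_cast<float>(a.count) * b.count / r.count);
return r;
}
// Layernorm epilogue per element e of a row, with var = m2 / count:
//   h = gamma * (e - mean) / sqrt(var + epsilon) + beta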
......@@ -4,13 +4,13 @@
#pragma once
#include <iostream>
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace ck {
......@@ -23,20 +23,33 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp>
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormFwd : public device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Rank bigger than 6 is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
......@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: p_x_(p_x),
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale),
bnBias_(bnBias),
y_elementwise_op_(y_elementwise_op),
......@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
ignore = xStrides;
ignore = yStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
ignore = reduceDims;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
n = xyLengths[0];
h = xyLengths[1];
w = xyLengths[2];
c = xyLengths[3];
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
......@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_;
const ScaleDataType* bnScale_;
const BiasDataType* bnBias_;
......@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
bool resultSave, resultRunning;
index_t n, h, w, c;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType averageFactor_;
AccDataType epsilon_;
......@@ -104,28 +168,28 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
index_t offset_C = iC;
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.y_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0;
// compute mean, variance using welford method
for(index_t iN = 0; iN < arg.n; iN++)
{
index_t offset_N = iN * arg.h * arg.w * arg.c;
for(index_t iH = 0; iH < arg.h; iH++)
{
index_t offset_H = iH * arg.w * arg.c;
for(index_t iW = 0; iW < arg.w; iW++)
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_W = iW * arg.c;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean;
......@@ -135,74 +199,88 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
variance += delta * delta2;
};
}
};
// actual variance
variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType invVariance =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
// save the mean/invVariance if required
// save the mean/inv-variance if required
if(arg.resultSave)
{
arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
invariant_index);
arg.resultSaveMean_[offset] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[offset] = type_convert<MeanVarDataType>(invVariance);
};
// update the moving average if required
if(arg.resultRunning)
{
size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
invariant_index);
AccDataType oneMinusAverageFactor =
type_convert<AccDataType>(1.0) - arg.averageFactor_;
arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
oneMinusAverageFactor +
mean * arg.averageFactor_);
arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
variance * arg.averageFactor_);
};
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
size_t bias_offset =
get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
// Normalization
for(index_t iN = 0; iN < arg.n; iN++)
{
index_t offset_N = iN * arg.h * arg.w * arg.c;
for(index_t iH = 0; iH < arg.h; iH++)
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_H = iH * arg.w * arg.c;
for(index_t iW = 0; iW < arg.w; iW++)
{
index_t offset_W = iW * arg.c;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.y_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
auto y_offset = y_invariant_offset + y_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
AccDataType norm_x = (x - mean) * invVariance;
arg.p_y_[offset] = type_convert<YDataType>(norm_x);
};
}
AccDataType y = scale * norm_x + bias;
arg.y_elementwise_op_(y, y);
arg.p_y_[y_offset] = type_convert<YDataType>(y);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(ic);
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
......@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
str << "Reference_BatchNorm_Forward" << std::endl;
// clang-format on
return str.str();
......
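The per-thread reduction in the rewritten reference uses Welford's online update; the running-mean line sits just past the hunk boundary. A self-contained scalar version of the same update (a sketch for clarity, assuming std::vector<float> samples; not the committed code):

// Welford's online mean/variance over a sequence of values:
float mean = 0.0f;
float variance = 0.0f; // accumulates M2, the sum of squared deviations
int curr_count = 0;
for(float x : samples)
{
curr_count++;
float delta = x - mean;
mean += delta / curr_count; // update running mean
float delta2 = x - mean; // deviation from the updated mean
variance += delta * delta2; // accumulate M2
};
variance = variance / curr_count; // actual variance, as in the reference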
......@@ -8,6 +8,7 @@
#include <array>
#include <algorithm>
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
namespace ck {
......@@ -19,114 +20,205 @@ template <typename XDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Rank bigger than 6 is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const MeanVarDataType* estimatedMean,
const MeanVarDataType* estimatedVariance,
YDataType* p_y)
: p_x_(p_x),
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale),
bnBias_(bnBias),
epsilon_(epsilon),
y_elementwise_op_(y_elementwise_op),
estimatedMean_(estimatedMean),
estimatedVariance_(estimatedVariance),
p_y_(p_y)
{
ignore = xStrides;
ignore = yStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
// check invariant_lengths_ and bnScaleBiasMeanVarLengths
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
c_ = xyLengths[3];
epsilon_ = type_convert<AccDataType>(epsilon);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_;
const ScaleDataType* bnScale_;
const BiasDataType* bnBias_;
double epsilon_;
const YElementwiseOp y_elementwise_op_;
const MeanVarDataType* estimatedMean_;
const MeanVarDataType* estimatedVariance_;
YDataType* p_y_;
index_t n_, h_, w_, c_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
index_t offset_C = iC;
AccDataType mean = arg.estimatedMean_[offset_C];
AccDataType variance = arg.estimatedVariance_[offset_C];
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.y_invariant_strides_, invariant_index);
size_t mean_variance_offset =
get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_, invariant_index);
AccDataType mean = arg.estimatedMean_[mean_variance_offset];
AccDataType variance = arg.estimatedVariance_[mean_variance_offset];
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType invVariance =
type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
type_convert<AccDataType>(1.0f) / std::sqrt(arg.epsilon_ + variance);
// Normalization
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
size_t bias_offset =
get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
// normalization
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_W = iW * arg.c_;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.y_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
auto y_offset = y_invariant_offset + y_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
AccDataType norm_x = (x - mean) * invVariance;
arg.p_y_[offset] = type_convert<YDataType>(norm_x);
};
}
AccDataType y = scale * norm_x + bias;
arg.y_elementwise_op_(y, y);
arg.p_y_[y_offset] = type_convert<YDataType>(y);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(ic);
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
......@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const void* estimatedMean,
const void* estimatedVariance,
void* p_y) override
......@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
......@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
static_cast<const ScaleDataType*>(bnScale),
static_cast<const BiasDataType*>(bnBias),
epsilon,
y_elementwise_op,
static_cast<const MeanVarDataType*>(estimatedMean),
static_cast<const MeanVarDataType*>(estimatedVariance),
static_cast<YDataType*>(p_y));
......@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
str << "Reference_BatchNorm_Infer<" << std::endl;
// clang-format on
return str.str();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// Int8
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -4,9 +4,11 @@
#pragma once
#include <vector>
#include <array>
#include <iostream>
#include <fstream>
#include <string>
#include <algorithm>
#include "ck/ck.hpp"
......@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
return (values);
}
template <int NDim>
static inline std::vector<std::array<index_t, NDim>>
get_index_set(const std::array<index_t, NDim>& dim_lengths)
{
static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
if constexpr(NDim == 1)
{
std::vector<std::array<index_t, NDim>> index_set;
for(int i = 0; i < dim_lengths[0]; i++)
{
std::array<index_t, 1> index{i};
index_set.push_back(index);
};
return index_set;
}
else
{
std::vector<std::array<index_t, NDim>> index_set;
std::array<index_t, NDim - 1> partial_dim_lengths;
std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());
std::vector<std::array<index_t, NDim - 1>> partial_index_set;
partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);
for(index_t i = 0; i < dim_lengths[0]; i++)
for(const auto& partial_index : partial_index_set)
{
std::array<index_t, NDim> index;
index[0] = i;
std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);
index_set.push_back(index);
};
return index_set;
};
};
template <int NDim>
static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
const std::array<index_t, NDim>& index)
{
size_t offset = 0;
for(int i = 0; i < NDim; i++)
offset += index[i] * strides[i];
return (offset);
};
} // namespace host_common
} // namespace ck
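The two helpers combine to replace the hard-coded NHWC loops: get_index_set enumerates every multi-index of an index space, and get_offset_from_index turns one into a linear offset. A quick illustration (values are made up):

std::array<ck::index_t, 2> lengths{2, 3};
std::array<ck::index_t, 2> strides{3, 1}; // packed row-major
auto index_set = ck::host_common::get_index_set<2>(lengths);
// index_set = {0,0},{0,1},{0,2},{1,0},{1,1},{1,2}, in that order
for(const auto& index : index_set)
{
size_t offset = ck::host_common::get_offset_from_index<2>(strides, index);
// offsets come out as 0, 1, 2, 3, 4, 5
}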
add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_blockwise_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
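// Reading the first entry above against the column header: UseMultiBlockInK = false,
// a 256-thread block arranged as a 128 (M) x 2 (K) thread cluster, each thread
// covering a 2 x 2 slice, with vector dimension 0 and width-2 vector accesses for
// x, y, scale, bias, mean, and variance.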
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
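// Registers both the blockwise and the multi-block-in-K bf16 instance families
// for Rank = 4, NumReduceDim = 3 with the PassThrough output elementwise op.
// The f16, f32, and f64 translation units below follow the same pattern.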
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_bf16_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_bf16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_blockwise_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f16_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f32_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f32_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F64 = double;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
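// Registers both the blockwise and the multi-block f64 instance families
// (rank-4 input, 3 reduced dimensions) for enumeration by the instance factory.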
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f64_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f64_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
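// The multi-block variants below differ from the blockwise ones only in
// UseMultiBlockInK = true: the mean/variance (Welford) reduction along the
// reduce dimensions is split across multiple workgroups, which is typically
// preferable when the reduce length is large.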
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_multiblock_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
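// Registers both int8 instance families (float accumulation) for rank-4
// input with 3 reduced dimensions.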
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -26,6 +26,7 @@ set(PROFILER_SOURCE
src/profile_groupnorm.cpp
src/profile_layernorm.cpp
src/profile_softmax.cpp
src/profile_batchnorm_fwd.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
@@ -57,5 +58,6 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc
target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
target_link_libraries(ckProfiler PRIVATE device_softmax_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batchnorm_instance)
rocm_install(TARGETS ckProfiler COMPONENT profiler)
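For reference, an out-of-tree consumer could link the newly added instance library the same way; a minimal CMake sketch, with hypothetical target and source names:
# Hypothetical consumer target; "my_batchnorm_tool" and its source file are illustrative.
add_executable(my_batchnorm_tool my_batchnorm_tool.cpp)
target_link_libraries(my_batchnorm_tool PRIVATE device_batchnorm_instance)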