Commit b7f500f0 authored by rocking5566's avatar rocking5566 Committed by rocking
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 694057a7 4e6a5575
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"
using XDataType = float;
using YDataType = float;
using AccDataType = float;
using ScaleDataType = AccDataType;
using BiasDataType = AccDataType;
using MeanVarDataType = AccDataType;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumBatchNormReduceDim = 3;
const double epsilon = std::numeric_limits<float>::epsilon();
const double averageFactor = 0.1;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
ck::index_t numXYElement =
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
1,
std::multiplies<ck::index_t>());
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
SimpleDeviceMem y(sizeof(YDataType) * numXYElement);
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
PassThrough,
Rank,
NumBatchNormReduceDim>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr,
nullptr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_bytes =
numXYElement * (sizeof(XDataType) + sizeof(YDataType)) +
numScaleBiasMeanVarElement * (sizeof(ScaleDataType) + sizeof(BiasDataType) +
sizeof(MeanVarDataType) + sizeof(MeanVarDataType));
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
if(found)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best intance
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr,
nullptr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
constexpr int Rank = 4; constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]}; const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm // input data of the batchnorm forward algorithm
...@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2}, {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
...@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
{ {
using ReferenceBatchNormFwdInstance = using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType, ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
InOutDataType, InOutDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
PassThroughOp>; PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
...@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2}, {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
#include "batchnorm_infer_impl.hpp" #include "batchnorm_infer_impl.hpp"
...@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
constexpr int Rank = 4; constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]}; const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm // input data of the batchnorm forward algorithm
...@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification) if(do_verification)
{ {
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using ReferenceBatchNormInferInstance = using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C< ck::tensor_operation::host::ReferenceBatchNormInfer<InOutDataType,
InOutDataType, InOutDataType,
InOutDataType, AccDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType>; PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{}; auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref = auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
...@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
epsilon, epsilon,
PassThroughOp{},
estimatedMean.mData.data(), estimatedMean.mData.data(),
estimatedVariance.mData.data(), estimatedVariance.mData.data(),
y_ref.mData.data()); y_ref.mData.data());
......
...@@ -13,7 +13,15 @@ namespace ck { ...@@ -13,7 +13,15 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp> template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator struct DeviceBatchNormFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator ...@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp> template <typename XDataType,
using DeviceBatchNormFwdPtr = typename YDataType,
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>; typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -13,13 +13,22 @@ namespace ck { ...@@ -13,13 +13,22 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim> template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator struct DeviceBatchNormInfer : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths, const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
...@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator ...@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
double epsilon, double epsilon,
const YElementwiseOp y_elementwise_op,
const void* estimatedMean, const void* estimatedMean,
const void* estimatedInvVariance, const void* estimatedInvVariance,
void* p_y) = 0; void* p_y) = 0;
...@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator ...@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, index_t NumBatchNormReduceDim> template <typename XDataType,
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>; typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -13,15 +13,14 @@ ...@@ -13,15 +13,14 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
// #include #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemmWelford,
typename ABDataType, typename ABDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EDataType,
...@@ -63,25 +62,26 @@ __global__ void ...@@ -63,25 +62,26 @@ __global__ void
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_b_grid, p_a_grid,
p_ds_grid, p_b_grid,
p_e_grid, p_ds_grid,
p_mean_grid, p_e_grid,
p_var_grid, p_mean_grid,
p_shared, p_var_grid,
a_element_op, p_shared,
b_element_op, a_element_op,
cde_element_op, b_element_op,
a_grid_desc_ak0_m_ak1, cde_element_op,
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, b_grid_desc_bk0_n_bk1,
e_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
mean_grid_desc_mblock_mperblock_nblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
var_grid_desc_mblock_mperblock_nblock, mean_grid_desc_mblock_mperblock_nblock,
block_2_etile_map); var_grid_desc_mblock_mperblock_nblock,
block_2_etile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -102,23 +102,23 @@ __global__ void ...@@ -102,23 +102,23 @@ __global__ void
#endif #endif
} }
// template <typename GridwiseWelfordLayernorm, template <typename GridwiseWelfordLayernorm,
// typename EDataType, typename EDataType,
// typename HDataType, typename HDataType,
// typename MeanDataType, typename MeanDataType,
// typename VarDataType> typename VarDataType>
// __global__ void __global__ void
// #if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid, kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid, const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid, const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid, HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size) index_t blkgroup_size)
// { {
// GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size); GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// } }
} // namespace ck } // namespace ck
...@@ -335,8 +335,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -335,8 +335,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1)); using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1)); using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -388,29 +387,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -388,29 +387,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
1, 1,
LoopSched>; LoopSched>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
// using GridwiseWelfordLayernorm = using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType, GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType, HDataType,
// MeanDataType, MeanDataType,
// VarDataType, VarDataType,
// AccDataType, AccDataType,
// HGridDesc_M_N, HGridDesc_M_N,
// MeanGridDesc_M_N, MeanVarGridDesc_M_N,
// GammaBetaGridDesc_N, GammaBetaGridDesc_N,
// MeanVarGridDesc_M, MeanVarGridDesc_M,
// BlockSize, BlockSize,
// LayernormMThreadClusterSize, LayernormThreadClusterSize_M_N::At(I0),
// LayernormNThreadClusterSize, LayernormThreadClusterSize_M_N::At(I1),
// LayernormMThreadSliceSize, LayernormThreadSliceSize_M_N::At(I0),
// LayernormNThreadSliceSize, LayernormThreadSliceSize_M_N::At(I1),
// LayernormESrcHDstVectorDim, LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize, LayernormESrcVectorSize,
// LayernormHDstVectorSize, LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize, LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize, LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>; LayernormMeanVarSrcDstVectorSize>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -451,7 +450,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -451,7 +450,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
...@@ -484,28 +483,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -484,28 +483,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}); });
// populate desc for Ds/E/F/G // populate desc for Ds/E/F/G
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
mean_grid_desc_m_n_, mean_grid_desc_m_n_,
var_grid_desc_m_n_, var_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_); ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
mean_grid_desc_mblock_mperblock_nblock_ = mean_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_); mean_grid_desc_m_n_);
var_grid_desc_mblock_mperblock_nblock_ = var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_); var_grid_desc_m_n_);
} }
...@@ -526,7 +525,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -526,7 +525,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// pointers // pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
MeanDataType* p_mean_grid_; // mean MeanDataType* p_mean_grid_; // mean
VarDataType* p_var_grid_; // variance * count VarDataType* p_var_grid_; // variance * count
...@@ -546,15 +545,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -546,15 +545,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
HGridDesc_M_N h_grid_desc_m_n_; HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_; mean_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_; var_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map // block-to-e-tile map
...@@ -579,15 +578,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -579,15 +578,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{ {
float avg_time = 0; float avg_time = 0;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.mean_grid_desc_m_n_, arg.mean_grid_desc_m_n_,
arg.var_grid_desc_m_n_, arg.var_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
} }
const index_t grid_size = const index_t grid_size =
...@@ -601,30 +600,32 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -601,30 +600,32 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const auto kernel_gemm_welford = const auto kernel_gemm_welford =
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle< kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
GridwiseGemm, GridwiseGemmWelford,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemmWelford::DsGridPointer,
EDataType, EDataType,
MeanDataType, MeanDataType,
VarDataType, VarDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1, typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1, typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemmWelford::
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemmWelford::
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock, EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
// const auto kernel_welford_layernorm = const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm, kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType, EDataType,
// HDataType, HDataType,
// MeanDataType, MeanDataType,
// VarDataType>; VarDataType>;
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -649,21 +650,21 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -649,21 +650,21 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.var_grid_desc_mblock_mperblock_nblock_, arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
// avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm, kernel_welford_layernorm,
// dim3(grid_size), dim3(grid_size),
// dim3(BlockSize), dim3(BlockSize),
// 0, 0,
// arg.p_e_grid_, arg.p_e_grid_,
// arg.p_mean_grid_, arg.p_mean_grid_,
// arg.p_var_grid_, arg.p_var_grid_,
// arg.p_h_grid_, arg.p_h_grid_,
// arg.blkGroupSize_); arg.blkGroupSize_);
return avg_time; return avg_time;
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
{ {
return launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
} }
......
...@@ -42,8 +42,15 @@ template <typename XDataType, ...@@ -42,8 +42,15 @@ template <typename XDataType,
index_t ScaleSrcVectorSize, index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize, index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize> index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp> YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......
...@@ -19,18 +19,115 @@ ...@@ -19,18 +19,115 @@
namespace ck { namespace ck {
template <typename XDataType, typename YDataType, typename MeanDataType, typename VarDataType> template <typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename ComputeDataType,
typename XYGridDesc_M_N,
typename MeanVarGridDesc_M_N,
typename GammaBetaGridDesc_N,
typename MeanVarGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d struct GridwiseWelfordSecondHalfLayernorm2d
{ {
__device__ static void Run(const XDataType* __restrict__ p_x_grid, static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<ComputeDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_N,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_mean_grid, const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid, const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid) HDataType* __restrict__ p_h_grid,
/*const MeanVarGridDesc_M_N& mean_grid_desc_m_k,
const MeanVarGridDesc_M_N& var_grid_desc_m_k,
const GammaBetaGridDesc_N& gamma_grid_desc_m,
const GammaBetaGridDesc_N& beta_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,*/
index_t blkgroup_size)
{ {
ignore = p_x_grid; ignore = p_e_grid;
ignore = p_mean_grid; ignore = p_mean_grid;
ignore = p_var_grid; ignore = p_var_grid;
ignore = p_y_grid; ignore = p_h_grid;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
/*
auto threadwise_mean_load_m_n =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarGridDesc_M_N,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_grid_desc_m_n,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * 1));*/
} // run } // run
}; };
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector>
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <thread> #include <thread>
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp" #include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace ck { namespace ck {
...@@ -23,20 +23,33 @@ template <typename XDataType, ...@@ -23,20 +23,33 @@ template <typename XDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename YElementwiseOp> typename YElementwiseOp,
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C index_t Rank,
: public device::DeviceBatchNormFwd<4, 3, YElementwiseOp> index_t NumBatchNormReduceDim>
struct ReferenceBatchNormFwd : public device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const std::array<index_t, 4> xyLengths, Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<int, 3> reduceDims, const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides, const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides, const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides, const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x, const XDataType* p_x,
const ScaleDataType* bnScale, const ScaleDataType* bnScale,
const BiasDataType* bnBias, const BiasDataType* bnBias,
...@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
double averageFactor, double averageFactor,
MeanVarDataType* resultRunningMean, MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance) MeanVarDataType* resultRunningVariance)
: p_x_(p_x), : reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
y_elementwise_op_(y_elementwise_op), y_elementwise_op_(y_elementwise_op),
...@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunningMean_(resultRunningMean), resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance) resultRunningVariance_(resultRunningVariance)
{ {
ignore = xStrides; using ck::host_common::get_index_set;
ignore = yStrides;
ignore = bnScaleStrides; if(std::any_of(
ignore = bnBiasStrides; reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
ignore = bnMeanVarStrides; throw std::runtime_error("Invalid reduce dimensions!");
ignore = reduceDims;
// get invariant_dims[] and invariant_lengths[]
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || for(int dim = 0, i = 0; dim < Rank; dim++)
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) if(std::none_of(
throw std::runtime_error("Invalid tensor dimensions!"); reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
n = xyLengths[0]; invariantDims_[i] = dim;
h = xyLengths[1]; invariant_lengths_[i] = xyLengths[dim];
w = xyLengths[2]; i++;
c = xyLengths[3]; };
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon); epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor); averageFactor_ = type_convert<AccDataType>(averageFactor);
...@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr); resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
} }
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_; const XDataType* p_x_;
const ScaleDataType* bnScale_; const ScaleDataType* bnScale_;
const BiasDataType* bnBias_; const BiasDataType* bnBias_;
...@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
bool resultSave, resultRunning; bool resultSave, resultRunning;
index_t n, h, w, c; std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType averageFactor_; AccDataType averageFactor_;
AccDataType epsilon_; AccDataType epsilon_;
...@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
{ {
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto thread_reduce_func = [&](auto iC) { using ck::host_common::get_offset_from_index;
index_t offset_C = iC;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.y_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f); AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f); AccDataType variance = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0; int32_t curr_count = 0;
// compute mean, variance using welford method // compute mean, variance using welford method
for(index_t iN = 0; iN < arg.n; iN++) for(const auto& reduce_index : arg.reduce_index_set_)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
for(index_t iH = 0; iH < arg.h; iH++) arg.x_reduce_strides_, reduce_index);
{
index_t offset_H = iH * arg.w * arg.c;
for(index_t iW = 0; iW < arg.w; iW++)
{
index_t offset_W = iW * arg.c;
auto offset = offset_N + offset_H + offset_W + offset_C; auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++; curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]); AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean; AccDataType delta = x - mean;
mean += delta / curr_count; mean += delta / curr_count;
AccDataType delta2 = x - mean; AccDataType delta2 = x - mean;
variance += delta * delta2; variance += delta * delta2;
};
}
}; };
// actual variance // actual variance
variance = variance / curr_count; variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType invVariance = AccDataType invVariance =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance); type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
// save the mean/invVariance if required // save the mean/inv-variance if required
if(arg.resultSave) if(arg.resultSave)
{ {
arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean); size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance); invariant_index);
arg.resultSaveMean_[offset] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[offset] = type_convert<MeanVarDataType>(invVariance);
}; };
// update the moving average if required // update the moving average if required
if(arg.resultRunning) if(arg.resultRunning)
{ {
size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
invariant_index);
AccDataType oneMinusAverageFactor = AccDataType oneMinusAverageFactor =
type_convert<AccDataType>(1.0) - arg.averageFactor_; type_convert<AccDataType>(1.0) - arg.averageFactor_;
arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>( arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
type_convert<AccDataType>(arg.resultRunningMean_[iC]) * type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
oneMinusAverageFactor + oneMinusAverageFactor +
mean * arg.averageFactor_); mean * arg.averageFactor_);
arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>( arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
arg.resultRunningVariance_[iC] * oneMinusAverageFactor + arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
variance * arg.averageFactor_); variance * arg.averageFactor_);
}; };
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
size_t bias_offset =
get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
// Normalization // Normalization
for(index_t iN = 0; iN < arg.n; iN++) for(const auto& reduce_index : arg.reduce_index_set_)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
for(index_t iH = 0; iH < arg.h; iH++) arg.x_reduce_strides_, reduce_index);
{ size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
index_t offset_H = iH * arg.w * arg.c; arg.y_reduce_strides_, reduce_index);
for(index_t iW = 0; iW < arg.w; iW++)
{
index_t offset_W = iW * arg.c;
auto offset = offset_N + offset_H + offset_W + offset_C; auto x_offset = x_invariant_offset + x_reduce_offset;
auto y_offset = y_invariant_offset + y_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]); AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = AccDataType norm_x = (x - mean) * invVariance;
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
arg.p_y_[offset] = type_convert<YDataType>(norm_x); AccDataType y = scale * norm_x + bias;
};
} arg.y_elementwise_op_(y, y);
arg.p_y_[y_offset] = type_convert<YDataType>(y);
}; };
}; };
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread; std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread); std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it) for(std::size_t it = 0; it < num_thread; ++it)
{ {
std::size_t ic_begin = it * work_per_thread; std::size_t i_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c); std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] { auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic) for(std::size_t i = i_begin; i < i_end; ++i)
{ {
thread_reduce_func(ic); thread_reduce_func(arg.invariant_index_set_[i]);
} }
}; };
...@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C ...@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl; str << "Reference_BatchNorm_Forward" << std::endl;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
namespace ck { namespace ck {
...@@ -19,114 +20,205 @@ template <typename XDataType, ...@@ -19,114 +20,205 @@ template <typename XDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType> typename MeanVarDataType,
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3> typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const std::array<index_t, 4> xyLengths, Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, 1> bnScaleStrides, const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnBiasStrides, const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, 1> bnMeanVarStrides, const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x, const XDataType* p_x,
const ScaleDataType* bnScale, const ScaleDataType* bnScale,
const BiasDataType* bnBias, const BiasDataType* bnBias,
double epsilon, double epsilon,
const YElementwiseOp y_elementwise_op,
const MeanVarDataType* estimatedMean, const MeanVarDataType* estimatedMean,
const MeanVarDataType* estimatedVariance, const MeanVarDataType* estimatedVariance,
YDataType* p_y) YDataType* p_y)
: p_x_(p_x), : reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
epsilon_(epsilon), y_elementwise_op_(y_elementwise_op),
estimatedMean_(estimatedMean), estimatedMean_(estimatedMean),
estimatedVariance_(estimatedVariance), estimatedVariance_(estimatedVariance),
p_y_(p_y) p_y_(p_y)
{ {
ignore = xStrides; using ck::host_common::get_index_set;
ignore = yStrides;
ignore = bnScaleStrides; if(std::any_of(
ignore = bnBiasStrides; reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
ignore = bnMeanVarStrides; throw std::runtime_error("Invalid reduce dimensions!");
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || // get invariant_dims[] and invariant_lengths[]
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) for(int dim = 0, i = 0; dim < Rank; dim++)
throw std::runtime_error("Invalid tensor dimensions!"); if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
n_ = xyLengths[0]; {
h_ = xyLengths[1]; invariantDims_[i] = dim;
w_ = xyLengths[2]; invariant_lengths_[i] = xyLengths[dim];
c_ = xyLengths[3]; i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
// check invariant_lengths_ and bnScaleBiasMeanVarLengths
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
} }
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_; const XDataType* p_x_;
const ScaleDataType* bnScale_; const ScaleDataType* bnScale_;
const BiasDataType* bnBias_; const BiasDataType* bnBias_;
const YElementwiseOp y_elementwise_op_;
double epsilon_;
const MeanVarDataType* estimatedMean_; const MeanVarDataType* estimatedMean_;
const MeanVarDataType* estimatedVariance_; const MeanVarDataType* estimatedVariance_;
YDataType* p_y_; YDataType* p_y_;
index_t n_, h_, w_, c_; std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
}; };
struct Invoker : public device::BaseInvoker
{
    float Run(const Argument& arg)
    {
        using ck::host_common::get_offset_from_index;

        auto thread_reduce_func = [&](auto invariant_index) {
            size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
                arg.x_invariant_strides_, invariant_index);
            size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
                arg.y_invariant_strides_, invariant_index);
            size_t mean_variance_offset =
                get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_, invariant_index);

            AccDataType mean     = arg.estimatedMean_[mean_variance_offset];
            AccDataType variance = arg.estimatedVariance_[mean_variance_offset];

            // inv-variance defined as 1/sqrt(epsilon+variance)
            AccDataType invVariance =
                type_convert<AccDataType>(1.0f) / std::sqrt(arg.epsilon_ + variance);

            size_t scale_offset =
                get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
            size_t bias_offset =
                get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);

            AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
            AccDataType bias  = type_convert<AccDataType>(arg.bnBias_[bias_offset]);

            // normalization
            for(const auto& reduce_index : arg.reduce_index_set_)
            {
                size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                    arg.x_reduce_strides_, reduce_index);
                size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                    arg.y_reduce_strides_, reduce_index);

                auto x_offset = x_invariant_offset + x_reduce_offset;
                auto y_offset = y_invariant_offset + y_reduce_offset;

                AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                AccDataType norm_x = (x - mean) * invVariance;

                AccDataType y = scale * norm_x + bias;

                arg.y_elementwise_op_(y, y);

                arg.p_y_[y_offset] = type_convert<YDataType>(y);
            };
        };

        std::size_t num_thread = std::thread::hardware_concurrency();
        std::size_t work_per_thread =
            (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

        std::vector<joinable_thread> threads(num_thread);

        for(std::size_t it = 0; it < num_thread; ++it)
        {
            std::size_t i_begin = it * work_per_thread;
            std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                           arg.invariant_index_set_.size());

            auto f = [=] {
                for(std::size_t i = i_begin; i < i_end; ++i)
                {
                    thread_reduce_func(arg.invariant_index_set_[i]);
                }
            };
...@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                    const std::array<index_t, Rank> xStrides,
                    const std::array<index_t, Rank> yStrides,
                    const std::array<int, NumBatchNormReduceDim> reduceDims,
                    const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                    const std::array<index_t, NumInvariantDim> bnScaleStrides,
                    const std::array<index_t, NumInvariantDim> bnBiasStrides,
                    const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                    const void* p_x,
                    const void* bnScale,
                    const void* bnBias,
                    double epsilon,
                    const YElementwiseOp y_elementwise_op,
                    const void* estimatedMean,
                    const void* estimatedVariance,
                    void* p_y) override
...@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
    return std::make_unique<Argument>(xyLengths,
                                      xStrides,
                                      yStrides,
                                      reduceDims,
                                      bnScaleBiasMeanVarLengths,
                                      bnScaleStrides,
                                      bnBiasStrides,
...@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                                      static_cast<const ScaleDataType*>(bnScale),
                                      static_cast<const BiasDataType*>(bnBias),
                                      epsilon,
                                      y_elementwise_op,
                                      static_cast<const MeanVarDataType*>(estimatedMean),
                                      static_cast<const MeanVarDataType*>(estimatedVariance),
                                      static_cast<YDataType*>(p_y));
...@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
    auto str = std::stringstream();
    // clang-format off
    str << "Reference_BatchNorm_Infer<" << std::endl;
    // clang-format on
    return str.str();
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// Int8
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
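// Note on dispatch (illustration only, not part of the original header): the
// factory above returns a non-empty op_ptrs vector only when the template
// arguments match one of the compiled instance sets, e.g.
// <F16, F16, F32, F16, F16, F32, PassThrough, 4, 3> selects
// add_device_batchnorm_forward_rank_4_3_f16_instances; any other combination
// of data types, rank, reduce-dim count, or elementwise op yields an empty vector.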
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -4,9 +4,11 @@
#pragma once

#include <vector>
#include <array>
#include <iostream>
#include <fstream>
#include <string>
#include <algorithm>

#include "ck/ck.hpp"
...@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
    return (values);
}
template <int NDim>
static inline std::vector<std::array<index_t, NDim>>
get_index_set(const std::array<index_t, NDim>& dim_lengths)
{
static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
if constexpr(NDim == 1)
{
std::vector<std::array<index_t, NDim>> index_set;
for(int i = 0; i < dim_lengths[0]; i++)
{
std::array<index_t, 1> index{i};
index_set.push_back(index);
};
return index_set;
}
else
{
std::vector<std::array<index_t, NDim>> index_set;
std::array<index_t, NDim - 1> partial_dim_lengths;
std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());
std::vector<std::array<index_t, NDim - 1>> partial_index_set;
partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);
for(index_t i = 0; i < dim_lengths[0]; i++)
for(const auto& partial_index : partial_index_set)
{
std::array<index_t, NDim> index;
index[0] = i;
std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);
index_set.push_back(index);
};
return index_set;
};
};
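// For example (illustration only, not part of the original header):
// get_index_set<2>({2, 3}) returns the six index tuples
// {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}, i.e. all coordinates of a 2x3
// index space enumerated in row-major order.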
template <int NDim>
static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
const std::array<index_t, NDim>& index)
{
size_t offset = 0;
for(int i = 0; i < NDim; i++)
offset += index[i] * strides[i];
return (offset);
};
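// For example (illustration only): with strides {12, 4, 1} and index {1, 2, 3},
// get_offset_from_index<3> returns 1*12 + 2*4 + 3*1 = 23, the linear offset of
// that multi-dimensional index under the given strides.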
} // namespace host_common
} // namespace ck
add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_blockwise_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_bf16_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_bf16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_blockwise_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f16_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f32_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f32_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F64 = double;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_f64_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_f64_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
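// Minimal usage sketch (assumed client-side code, not part of this commit): the f64 instances
// registered above can be collected directly via the instance-adding function, provided the
// DeviceBatchNormFwd interface and the function declaration are visible to the caller; which
// header a client should include for that declaration is an assumption here, not taken from
// this diff.
#include <memory>
#include <vector>

void collect_batchnorm_fwd_f64_instances()
{
    namespace dev = ck::tensor_operation::device;

    // One polymorphic handle per DeviceBatchNormFwdImpl configuration registered above
    // (blockwise plus multiblock).
    std::vector<std::unique_ptr<dev::DeviceBatchNormFwd<double, double, double, double, double,
                                                        double,
                                                        ck::tensor_operation::element_wise::PassThrough,
                                                        4, 3>>>
        instances;

    dev::instance::add_device_batchnorm_forward_rank_4_3_f64_instances(instances);
}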
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
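// Hedged note (editorial; the alias below is illustrative and not defined by this commit): the
// int8 instances above keep int8 for x, y, scale and bias, but accumulate and store the saved
// mean/variance results in float, so a caller holds them through the following base-class type.
using I8BatchNormFwdRank4Reduce3 =
    ck::tensor_operation::device::DeviceBatchNormFwd<int8_t, int8_t, float, int8_t, int8_t, float,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     4, 3>;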
...@@ -26,6 +26,7 @@ set(PROFILER_SOURCE
src/profile_groupnorm.cpp
src/profile_layernorm.cpp
src/profile_softmax.cpp
src/profile_batchnorm_fwd.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
...@@ -57,5 +58,6 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc
target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
target_link_libraries(ckProfiler PRIVATE device_softmax_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batchnorm_instance)
rocm_install(TARGETS ckProfiler COMPONENT profiler)