Unverified Commit 7e1acffc authored by Adam Osewski, committed by GitHub

Merge branch 'develop' into wavelet_model

parents 2c2a6277 0e9c88ce
add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
...@@ -53,4 +53,29 @@ Start running 10 times...
Perf: 1.28235 ms, 523.329 GB/s
```
## Run ```batchnorm backward nhwc```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
# Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bf16, 6: fp64)
# Arg2: 1/0 to indicate whether to use saved mean and invVariance
# Arg3: init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
# Arg4: time kernel (0=no, 1=yes)
# Arg5: use multi-block welford (0=no, 1=yes)
./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1
```
Result
```
./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.411026 ms, 91.8702 GB/s
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
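// Abstract interface for a batchnorm backward device operator: x, dy and dx are Rank-dimensional
// tensors reduced over NumBatchNormReduceDim dimensions, while scale/bias/mean/variance (and the
// dscale/dbias outputs) live on the remaining NumInvariantDim dimensions.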
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
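For orientation, here is a hypothetical sketch (not part of the diff) of how a concrete implementation of this interface is typically driven: build the argument, build the invoker, run. The helper name `run_batchnorm_bwd_nhwc`, the 128x16x3x1024 lengths and the epsilon value are illustrative assumptions, and the device buffers are expected to be allocated and filled elsewhere.

```cpp
// Hypothetical usage sketch for DeviceBatchNormBwd<4, 3, PassThrough> on a packed NHWC problem.
#include <array>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

float run_batchnorm_bwd_nhwc(ck::tensor_operation::device::DeviceBatchNormBwd<4, 3, PassThrough>& op,
                             const void* p_x, const void* p_dy, const void* p_scale,
                             const void* p_saved_mean, const void* p_saved_inv_var,
                             void* p_dx, void* p_dscale, void* p_dbias)
{
    const std::array<ck::index_t, 4> lengths   = {128, 16, 3, 1024};                 // N, H, W, C
    const std::array<ck::index_t, 4> strides   = {16 * 3 * 1024, 3 * 1024, 1024, 1}; // packed NHWC
    const std::array<int, 3> reduce_dims       = {0, 1, 2};                          // reduce N, H, W
    const std::array<ck::index_t, 1> c_length  = {1024};                             // per-channel
    const std::array<ck::index_t, 1> c_stride  = {1};

    auto arg = op.MakeArgumentPointer(lengths, strides, strides, strides, reduce_dims,
                                      c_length, c_stride, c_stride, c_stride,
                                      p_x, p_dy, p_scale, p_saved_mean, p_saved_inv_var,
                                      1e-5, // epsilon (assumed value)
                                      PassThrough{},
                                      p_dx, p_dscale, p_dbias);

    if(!op.IsSupportedArgument(arg.get()))
        return -1.0f; // this particular instance cannot handle the problem

    // the invoker reports the measured kernel time when timing is enabled in the StreamConfig
    return op.MakeInvokerPointer()->Run(arg.get());
}
```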
...@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
...
...@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
...
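As an aside, the sketch below illustrates (with my own naming, not the library's API) the two Welford building blocks this "first half" relies on: a streaming per-element update and an exact merge of two partial results, which is what makes the two-kernel multiblock split possible.

```cpp
#include <cstdint>

struct WelfordPartial
{
    double mean   = 0.0;
    double m2     = 0.0; // sum of squared deviations from the running mean
    int32_t count = 0;
};

// per-element update (what a thread does while streaming its K tiles)
inline void welford_update(WelfordPartial& w, double x)
{
    w.count += 1;
    const double delta = x - w.mean;
    w.mean += delta / w.count;
    w.m2 += delta * (x - w.mean);
}

// exact merge of two partials (what the blockwise/second-half reduction does, conceptually)
inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    const double delta = b.mean - a.mean;
    out.mean = a.mean + delta * b.count / out.count;
    out.m2   = a.m2 + b.m2 + delta * delta * static_cast<double>(a.count) * b.count / out.count;
    return out; // biased variance = out.m2 / out.count, as batchnorm uses
}
```

The kernel above stores per-block (mean, variance, count) triples rather than m2 (variance = m2 / count); the second-half gridwise kernel then performs the corresponding merge across each block group.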
...@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
...
...@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
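// Thin __global__ entry point: all of the work is delegated to the gridwise policy class that is
// passed in as GridwiseMultiblockWelfordFirstHalf_.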
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
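// Blocks are grouped along the reduced (K) dimension: blkgroup_id selects the M tile this group of
// blocks covers, while block_local_id is the block's position inside the group, i.e. which slice of
// the K reduction (and which column of the M x G partial-result output) it owns.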
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
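// Each thread streams its assigned K tiles and accumulates a partial Welford mean/variance in
// registers; the per-thread partials are then merged across the thread cluster below.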
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
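// After the blockwise merge, the threads with thread_k_cluster_id == 0 hold this block's partial
// mean/variance/count, which are written out per block-group column for the second-half kernel.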
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <sstream>
#include <thread>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
: p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
ignore = xStrides;
ignore = dyStrides;
ignore = dxStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
throw std::runtime_error("Invalid reduce dimensions!");
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
c_ = xyLengths[3];
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
double epsilon_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
index_t n_, h_, w_, c_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
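// Per-channel work: obtain mean and inv-variance (saved or recomputed), reduce dy and
// dy * norm_x over the N*H*W elements to get dbias/dscale, then compute dx in a second pass.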
AccDataType reduceSize = type_convert<AccDataType>(arg.n_) *
type_convert<AccDataType>(arg.h_) *
type_convert<AccDataType>(arg.w_);
index_t offset_C = iC;
AccDataType mean;
AccDataType invVar;
if(arg.haveSavedMeanInvVar_)
{
mean = arg.p_savedMean_[offset_C];
invVar = arg.p_savedInvVar_[offset_C];
}
else
{
AccDataType meansquare;
meansquare = type_convert<AccDataType>(0.0f);
mean = type_convert<AccDataType>(0.0f);
// compute mean, meanquare, variance, inv-variance
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
mean += x;
meansquare += x * x;
};
}
};
mean = mean / reduceSize;
meansquare = meansquare / reduceSize;
AccDataType variance = meansquare - mean * mean;
invVar = type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
};
AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on NHW dimensions
// 3) calculate sum(dy * norm_x) on NHW dimensions
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
}
};
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType multiplier =
type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal);
arg.p_dx_[offset] = type_convert<DxDataType>(dx);
};
}
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
{
thread_reduce_func(ic);
}
};
threads[it] = joinable_thread(f);
}
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
(void)p_arg;
return (true);
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dyStrides,
dxStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
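In formula form (notation mine, matching the loops above; dy is the value after the DyElementwiseOp has been applied, and m = N·H·W is the reduction size per channel):

```latex
\sigma_{\mathrm{inv}} = \frac{1}{\sqrt{\mathrm{Var}[x] + \epsilon}}, \qquad
\hat{x}_i = (x_i - \mu)\,\sigma_{\mathrm{inv}}

d\beta = \sum_{i=1}^{m} dy_i, \qquad
d\gamma = \sum_{i=1}^{m} dy_i\,\hat{x}_i, \qquad
dx_i = \frac{\gamma\,\sigma_{\mathrm{inv}}}{m}\,\bigl(m\,dy_i - d\beta - \hat{x}_i\,d\gamma\bigr)
```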
...@@ -31,10 +31,6 @@ void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// Int8
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
...@@ -101,15 +97,6 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
...
...@@ -2,6 +2,5 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_multiblock_instances =
std::tuple <
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_i8_instances(
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&
instances)
{
add_device_operation_instances(
instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
add_device_operation_instances(
instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
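A hypothetical usage sketch (not from the repository): it fills a vector through the add_device_batchnorm_forward_rank_4_3_i8_instances entry point defined above and prints each instance's type string. The include path for the DeviceBatchNormFwd base header is an assumption, and the declaration of the add-instances function is assumed to be visible from the corresponding instance header.

```cpp
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" // assumed header path
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

int main()
{
    using ck::tensor_operation::device::DeviceBatchNormFwd;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<int8_t, int8_t, float, int8_t, int8_t, float, PassThrough, 4, 3>>>
        op_ptrs;

    ck::tensor_operation::device::instance::add_device_batchnorm_forward_rank_4_3_i8_instances(
        op_ptrs);

    // Each entry is one tuning configuration; a caller would normally create an argument,
    // filter the candidates with IsSupportedArgument() and time the remaining ones.
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << std::endl;
}
```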
...@@ -85,39 +85,22 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(updateMovingAverage)
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.04f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize the runningMean to be values with tiny variation to the mean of the x
// values
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
// initialize the runningVariance to be values with tiny variation to the variance of
// the x values
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
}
else
{
...@@ -129,35 +112,24 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(do_verification)
{
switch(init_method)
{
case 0:
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_0<BiasDataType>{}, num_thread);
break;
case 1:
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_1<BiasDataType>{0}, num_thread);
break;
case 2:
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5}, num_thread);
break;
default:
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-1.0f, 1.0f}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1.0f, 1.0f}, num_thread);
}
};
// these buffers are usually provided by the user application
...