Commit 05ee41c3 authored by Rosty Geyyer

Merge branch 'develop' into lwpck-471

parents 37116c98 ad541ad6
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
@@ -132,7 +132,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 {
     static constexpr auto I0 = Number<0>{};
@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...
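The hunk above replaces the hard-coded GridwiseGemmPipeline_v1 with a selector, so the pipeline version becomes a template parameter of the gridwise GEMM. A minimal, self-contained sketch of the same decltype-based compile-time dispatch pattern; PipelineV1, PipelineV2 and pipeline_selector below are illustrative stand-ins, not CK's actual classes:

#include <type_traits>

// Hypothetical pipeline implementations; the real CK pipelines carry many more parameters.
enum struct PipelineVersion { v1, v2 };

template <int NumPrefetch> struct PipelineV1 {};
template <int NumPrefetch> struct PipelineV2 {};

// Return a value of the selected type so the caller can recover the type with
// decltype, which is exactly what the new GridwiseGemmPipe alias does.
template <PipelineVersion Ver, int NumPrefetch>
constexpr auto pipeline_selector()
{
    if constexpr(Ver == PipelineVersion::v1)
        return PipelineV1<NumPrefetch>{};
    else
        return PipelineV2<NumPrefetch>{};
}

using SelectedPipe = std::remove_cv_t<
    std::remove_reference_t<decltype(pipeline_selector<PipelineVersion::v1, 1>())>>;
static_assert(std::is_same_v<SelectedPipe, PipelineV1<1>>, "v1 selected");

int main() { return 0; }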
@@ -14,7 +14,7 @@
 namespace ck {
-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
@@ -36,7 +36,7 @@ template <typename XDataType,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
           bool SweepOnce>
-struct GridwiseLayernormNaiveVariance_mk_to_mk
+struct GridwiseNormalizationNaiveVariance_mk_to_mk
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                       (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
...
@@ -11,7 +11,7 @@
 namespace ck {
-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
@@ -33,7 +33,7 @@ template <typename XDataType,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
           bool SweepOnce>
-struct GridwiseLayernormWelfordVariance_mk_to_mk
+struct GridwiseNormalizationWelfordVariance_mk_to_mk
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                       (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
...
@@ -4,6 +4,7 @@
 #pragma once
+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -47,10 +48,9 @@ struct TransformConvFwdToGemm
        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
-            const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                    index_t{1},
-                                                    std::multiplies<index_t>());
+            const index_t NWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
@@ -146,10 +146,9 @@ struct TransformConvFwdToGemm
        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
@@ -262,10 +261,8 @@ struct TransformConvFwdToGemm
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
@@ -390,10 +387,9 @@ struct TransformConvFwdToGemm
        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -506,10 +502,9 @@ struct TransformConvFwdToGemm
        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -639,10 +634,8 @@ struct TransformConvFwdToGemm
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -768,10 +761,8 @@ struct TransformConvFwdToGemm
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const auto wei_gemmn_gemmk_desc =
            make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
@@ -794,10 +785,8 @@ struct TransformConvFwdToGemm
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t KStride = b_g_k_c_xs_strides[1];
        const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
@@ -827,10 +816,9 @@ struct TransformConvFwdToGemm
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
@@ -855,10 +843,9 @@ struct TransformConvFwdToGemm
        const auto KStride = I1;
        const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
@@ -878,10 +865,9 @@ struct TransformConvFwdToGemm
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
...
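Every hunk above swaps std::accumulate over an iterator pair for ck::accumulate_n, which takes an element count instead of an end iterator. A plausible sketch of such a helper and its use, assuming the call shape seen above; the real helper lives in ck/library/utility/numeric.hpp and may differ in details:

#include <array>
#include <functional>
#include <iterator>
#include <numeric>

// Count-based accumulate: fold op over [first, first + n), starting from init.
template <typename T, typename ForwardIt, typename Size, typename BinaryOp>
T accumulate_n(ForwardIt first, Size n, T init, BinaryOp op)
{
    return std::accumulate(first, std::next(first, n), init, op);
}

int main()
{
    // Hypothetical G/N/K followed by 3 spatial output lengths, as in c_g_n_k_wos_lengths.
    std::array<int, 6> lengths{1, 4, 8, 7, 7, 5};
    const int NDimSpatial = 3;
    // Product of the spatial lengths, i.e. Do * Ho * Wo.
    int DoHoWo = accumulate_n<int>(lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
    return DoHoWo == 7 * 7 * 5 ? 0 : 1;
}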
@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
DscaleDbiasDataType* p_dscale,
DscaleDbiasDataType* p_dbias)
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
dy_invariant_strides_[i] = dyStrides[dim];
dx_invariant_strides_[i] = dxStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
dy_reduce_strides_[i] = dyStrides[dim];
dx_reduce_strides_[i] = dxStrides[dim];
i++;
};
reduceSize_ = std::accumulate(
reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies<size_t>{});
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> dy_invariant_strides_;
std::array<index_t, NumInvariantDim> dx_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
DscaleDbiasDataType* p_dscale_;
DscaleDbiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
size_t reduceSize_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dy_invariant_strides_, invariant_index);
size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dx_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f);
AccDataType invVar;
int32_t curr_count = 0;
if(arg.haveSavedMeanInvVar_)
{
size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.bnMeanVarStrides_, invariant_index);
mean =
type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
invVar =
type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
}
else
{
// compute mean, variance using welford method
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean;
mean += delta / curr_count;
AccDataType delta2 = x - mean;
variance += delta * delta2;
};
// actual variance
variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
invVar =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
};
AccDataType dbias =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
AccDataType dscale =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on reduced dimensions
// 3) calculate sum(dy * norm_x) on reduced dimensions
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
arg.p_dbias_[dbias_offset] = type_convert<DscaleDbiasDataType>(dbias);
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);
AccDataType multiplier = type_convert<AccDataType>(1.0f) /
type_convert<AccDataType>(arg.reduceSize_) * invVar *
scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
// - tmp)
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dx_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
auto dx_offset = dx_invariant_offset + dx_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (type_convert<AccDataType>(arg.reduceSize_) * dy -
dbias - tmpVal);
arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
(void)p_arg;
return (true);
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dxStrides,
dyStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnDscaleDbiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<DscaleDbiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Backward" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
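In the usual batchnorm notation, with \(\hat{x}_i = (x_i - \mu)\,\mathrm{invVar}\), \(\gamma\) = scale, \(\beta\) = bias and \(N\) = reduceSize_, the per-invariant-index gradients computed by this reference are

\[
d\beta = \sum_{i=1}^{N} dy_i, \qquad
d\gamma = \sum_{i=1}^{N} dy_i\,\hat{x}_i, \qquad
dx_i = \frac{\gamma\,\mathrm{invVar}}{N}\bigl(N\,dy_i - d\beta - \hat{x}_i\,d\gamma\bigr),
\]

where \(\mathrm{invVar} = 1/\sqrt{\sigma^2 + \varepsilon}\) is either taken from the saved mean/inv-variance buffers or recomputed with Welford's method, as in the constructor and Invoker above.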
@@ -4,13 +4,13 @@
 #pragma once
 #include <iostream>
-#include <vector>
 #include <array>
 #include <algorithm>
 #include <thread>
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/ignore.hpp"
+#include "ck/library/utility/host_common_util.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 namespace ck {
@@ -23,20 +23,33 @@ template <typename XDataType,
           typename ScaleDataType,
           typename BiasDataType,
           typename MeanVarDataType,
-          typename YElementwiseOp>
-struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
-    : public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+struct ReferenceBatchNormFwd : public device::DeviceBatchNormFwd<XDataType,
+                                                                 YDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 BiasDataType,
+                                                                 MeanVarDataType,
+                                                                 YElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>
 {
+    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
     struct Argument : public device::BaseArgument
     {
-        Argument(const std::array<index_t, 4> xyLengths,
-                 const std::array<index_t, 4> xStrides,
-                 const std::array<index_t, 4> yStrides,
-                 const std::array<int, 3> reduceDims,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleStrides,
-                 const std::array<index_t, 1> bnBiasStrides,
-                 const std::array<index_t, 1> bnMeanVarStrides,
+        Argument(const std::array<index_t, Rank> xyLengths,
+                 const std::array<index_t, Rank> xStrides,
+                 const std::array<index_t, Rank> yStrides,
+                 const std::array<int, NumBatchNormReduceDim> reduceDims,
+                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
+                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
+                 const std::array<index_t, NumInvariantDim> bnBiasStrides,
+                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* bnScale,
                 const BiasDataType* bnBias,
@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
                 double averageFactor,
                 MeanVarDataType* resultRunningMean,
                 MeanVarDataType* resultRunningVariance)
-            : p_x_(p_x),
+            : reduceDims_(reduceDims),
+              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
+              bnScaleStrides_(bnScaleStrides),
+              bnBiasStrides_(bnBiasStrides),
+              bnMeanVarStrides_(bnMeanVarStrides),
+              p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
              y_elementwise_op_(y_elementwise_op),
@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
              resultRunningMean_(resultRunningMean),
              resultRunningVariance_(resultRunningVariance)
        {
-            ignore = xStrides;
-            ignore = yStrides;
-            ignore = bnScaleStrides;
-            ignore = bnBiasStrides;
-            ignore = bnMeanVarStrides;
-            ignore = reduceDims;
-            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
-               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
-                throw std::runtime_error("Invalid tensor dimensions!");
-            n = xyLengths[0];
-            h = xyLengths[1];
-            w = xyLengths[2];
-            c = xyLengths[3];
+            using ck::host_common::get_index_set;
+            if(std::any_of(
+                   reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
+                throw std::runtime_error("Invalid reduce dimensions!");
+            // get invariant_dims[] and invariant_lengths[]
+            for(int dim = 0, i = 0; dim < Rank; dim++)
+                if(std::none_of(
+                       reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
+                {
+                    invariantDims_[i] = dim;
+                    invariant_lengths_[i] = xyLengths[dim];
+                    i++;
+                };
+            // get reduce_lengths_[]
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim = reduceDims[j];
+                reduce_lengths_[i++] = xyLengths[dim];
+            };
+            for(int i = 0; i < NumInvariantDim; i++)
+                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
+                    throw std::runtime_error("Invalid lengths parameters!");
+            for(int j = 0, i = 0; j < NumInvariantDim; j++)
+            {
+                int dim = invariantDims_[j];
+                x_invariant_strides_[i] = xStrides[dim];
+                y_invariant_strides_[i] = yStrides[dim];
+                i++;
+            };
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim = reduceDims_[j];
+                x_reduce_strides_[i] = xStrides[dim];
+                y_reduce_strides_[i] = yStrides[dim];
+                i++;
+            };
+            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
+            reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
            epsilon_ = type_convert<AccDataType>(epsilon);
            averageFactor_ = type_convert<AccDataType>(averageFactor);
@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
            resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
        }
+        std::array<int, NumBatchNormReduceDim> reduceDims_;
+        std::array<int, NumInvariantDim> invariantDims_;
+        std::array<index_t, NumInvariantDim> invariant_lengths_;
+        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
+        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
+        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
+        const std::array<index_t, NumInvariantDim> bnBiasStrides_;
+        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
+        std::array<index_t, NumInvariantDim> x_invariant_strides_;
+        std::array<index_t, NumInvariantDim> y_invariant_strides_;
+        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
+        std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
        const XDataType* p_x_;
        const ScaleDataType* bnScale_;
        const BiasDataType* bnBias_;
@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
        bool resultSave, resultRunning;
-        index_t n, h, w, c;
+        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
+        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
        AccDataType averageFactor_;
        AccDataType epsilon_;
@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
    {
        float Run(const Argument& arg)
        {
-            auto thread_reduce_func = [&](auto iC) {
-                index_t offset_C = iC;
+            using ck::host_common::get_offset_from_index;
+            auto thread_reduce_func = [&](auto invariant_index) {
+                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.x_invariant_strides_, invariant_index);
+                size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.y_invariant_strides_, invariant_index);
                AccDataType mean = type_convert<AccDataType>(0.0f);
                AccDataType variance = type_convert<AccDataType>(0.0f);
                int32_t curr_count = 0;
                // compute mean, variance using welford method
-                for(index_t iN = 0; iN < arg.n; iN++)
+                for(const auto& reduce_index : arg.reduce_index_set_)
                {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
-                    {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
-                        {
-                            index_t offset_W = iW * arg.c;
-                            auto offset = offset_N + offset_H + offset_W + offset_C;
-                            curr_count++;
-                            AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
-                            AccDataType delta = x - mean;
-                            mean += delta / curr_count;
-                            AccDataType delta2 = x - mean;
-                            variance += delta * delta2;
-                        };
-                    }
+                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.x_reduce_strides_, reduce_index);
+                    auto x_offset = x_invariant_offset + x_reduce_offset;
+                    curr_count++;
+                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
+                    AccDataType delta = x - mean;
+                    mean += delta / curr_count;
+                    AccDataType delta2 = x - mean;
+                    variance += delta * delta2;
                };
                // actual variance
                variance = variance / curr_count;
+                // inv-variance defined as 1/sqrt(epsilon+variance)
                AccDataType invVariance =
                    type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
-                // save the mean/invVariance if required
+                // save the mean/inv-variance if required
                if(arg.resultSave)
                {
-                    arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean);
-                    arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
+                    size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
+                                                                           invariant_index);
+                    arg.resultSaveMean_[offset] = type_convert<MeanVarDataType>(mean);
+                    arg.resultSaveInvVariance_[offset] = type_convert<MeanVarDataType>(invVariance);
                };
                // update the moving average if required
                if(arg.resultRunning)
                {
+                    size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
+                                                                           invariant_index);
                    AccDataType oneMinusAverageFactor =
                        type_convert<AccDataType>(1.0) - arg.averageFactor_;
-                    arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
-                        type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
+                    arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
+                        type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
                            oneMinusAverageFactor +
                        mean * arg.averageFactor_);
-                    arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
-                        arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
+                    arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
+                        arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
                        variance * arg.averageFactor_);
                };
+                size_t scale_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
+                size_t bias_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
+                AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
+                AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
                // Normalization
-                for(index_t iN = 0; iN < arg.n; iN++)
+                for(const auto& reduce_index : arg.reduce_index_set_)
                {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
-                    {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
-                        {
-                            index_t offset_W = iW * arg.c;
-                            auto offset = offset_N + offset_H + offset_W + offset_C;
-                            AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
-                            AccDataType norm_x =
-                                arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
-                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
-                        };
-                    }
+                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.x_reduce_strides_, reduce_index);
+                    size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.y_reduce_strides_, reduce_index);
+                    auto x_offset = x_invariant_offset + x_reduce_offset;
+                    auto y_offset = y_invariant_offset + y_reduce_offset;
+                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
+                    AccDataType norm_x = (x - mean) * invVariance;
+                    AccDataType y = scale * norm_x + bias;
+                    arg.y_elementwise_op_(y, y);
+                    arg.p_y_[y_offset] = type_convert<YDataType>(y);
                };
            };
            std::size_t num_thread = std::thread::hardware_concurrency();
-            std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
+            std::size_t work_per_thread =
+                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
            std::vector<joinable_thread> threads(num_thread);
            for(std::size_t it = 0; it < num_thread; ++it)
            {
-                std::size_t ic_begin = it * work_per_thread;
-                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
+                std::size_t i_begin = it * work_per_thread;
+                std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                             arg.invariant_index_set_.size());
                auto f = [=] {
-                    for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
+                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
-                        thread_reduce_func(ic);
+                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };
@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
        auto str = std::stringstream();
        // clang-format off
-        str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
+        str << "Reference_BatchNorm_Forward" << std::endl;
        // clang-format on
        return str.str();
...
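For reference, the Welford recurrence used in the training-mode loop above, followed by the running-statistics update, is

\[
m_k = m_{k-1} + \frac{x_k - m_{k-1}}{k}, \qquad
M_k = M_{k-1} + (x_k - m_{k-1})(x_k - m_k), \qquad
\sigma^2 = \frac{M_K}{K},
\]
\[
\mathrm{invVariance} = \frac{1}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
\mathrm{running} \leftarrow (1 - \alpha)\,\mathrm{running} + \alpha\,\mathrm{batch},
\]

where \(K\) is the number of reduced elements and \(\alpha\) is averageFactor_. Note that the variance divides by \(K\) (biased estimate), matching "variance / curr_count" in the code.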
@@ -8,6 +8,7 @@
 #include <array>
 #include <algorithm>
+#include "ck/library/utility/host_common_util.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
 namespace ck {
@@ -19,114 +20,205 @@ template <typename XDataType,
           typename AccDataType,
           typename ScaleDataType,
           typename BiasDataType,
-          typename MeanVarDataType>
-struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer<XDataType,
+                                                                     YDataType,
+                                                                     AccDataType,
+                                                                     ScaleDataType,
+                                                                     BiasDataType,
+                                                                     MeanVarDataType,
+                                                                     YElementwiseOp,
+                                                                     Rank,
+                                                                     NumBatchNormReduceDim>
 {
+    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
     struct Argument : public device::BaseArgument
     {
-        Argument(const std::array<index_t, 4> xyLengths,
-                 const std::array<index_t, 4> xStrides,
-                 const std::array<index_t, 4> yStrides,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleStrides,
-                 const std::array<index_t, 1> bnBiasStrides,
-                 const std::array<index_t, 1> bnMeanVarStrides,
+        Argument(const std::array<index_t, Rank> xyLengths,
+                 const std::array<index_t, Rank> xStrides,
+                 const std::array<index_t, Rank> yStrides,
+                 const std::array<int, NumBatchNormReduceDim> reduceDims,
+                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
+                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
+                 const std::array<index_t, NumInvariantDim> bnBiasStrides,
+                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* bnScale,
                 const BiasDataType* bnBias,
                 double epsilon,
+                 const YElementwiseOp y_elementwise_op,
                 const MeanVarDataType* estimatedMean,
                 const MeanVarDataType* estimatedVariance,
                 YDataType* p_y)
-            : p_x_(p_x),
+            : reduceDims_(reduceDims),
+              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
+              bnScaleStrides_(bnScaleStrides),
+              bnBiasStrides_(bnBiasStrides),
+              bnMeanVarStrides_(bnMeanVarStrides),
+              p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
-              epsilon_(epsilon),
+              y_elementwise_op_(y_elementwise_op),
              estimatedMean_(estimatedMean),
              estimatedVariance_(estimatedVariance),
              p_y_(p_y)
        {
-            ignore = xStrides;
-            ignore = yStrides;
-            ignore = bnScaleStrides;
-            ignore = bnBiasStrides;
-            ignore = bnMeanVarStrides;
-            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
-               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
-                throw std::runtime_error("Invalid tensor dimensions!");
-            n_ = xyLengths[0];
-            h_ = xyLengths[1];
-            w_ = xyLengths[2];
-            c_ = xyLengths[3];
+            using ck::host_common::get_index_set;
+            if(std::any_of(
+                   reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
+                throw std::runtime_error("Invalid reduce dimensions!");
+            // get invariant_dims[] and invariant_lengths[]
+            for(int dim = 0, i = 0; dim < Rank; dim++)
+                if(std::none_of(
+                       reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
+                {
+                    invariantDims_[i] = dim;
+                    invariant_lengths_[i] = xyLengths[dim];
+                    i++;
+                };
+            // get reduce_lengths_[]
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim = reduceDims[j];
+                reduce_lengths_[i++] = xyLengths[dim];
+            };
+            // check invariant_lengths_ and bnScaleBiasMeanVarLengths
+            for(int i = 0; i < NumInvariantDim; i++)
+                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
+                    throw std::runtime_error("Invalid lengths parameters!");
+            for(int j = 0, i = 0; j < NumInvariantDim; j++)
+            {
+                int dim = invariantDims_[j];
+                x_invariant_strides_[i] = xStrides[dim];
+                y_invariant_strides_[i] = yStrides[dim];
+                i++;
+            };
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim = reduceDims_[j];
+                x_reduce_strides_[i] = xStrides[dim];
+                y_reduce_strides_[i] = yStrides[dim];
+                i++;
+            };
+            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
+            reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
+            epsilon_ = type_convert<AccDataType>(epsilon);
        }
+        std::array<int, NumBatchNormReduceDim> reduceDims_;
+        std::array<int, NumInvariantDim> invariantDims_;
+        std::array<index_t, NumInvariantDim> invariant_lengths_;
+        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
+        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
+        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
+        const std::array<index_t, NumInvariantDim> bnBiasStrides_;
+        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
+        std::array<index_t, NumInvariantDim> x_invariant_strides_;
+        std::array<index_t, NumInvariantDim> y_invariant_strides_;
+        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
+        std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
        const XDataType* p_x_;
        const ScaleDataType* bnScale_;
        const BiasDataType* bnBias_;
+        const YElementwiseOp y_elementwise_op_;
-        double epsilon_;
        const MeanVarDataType* estimatedMean_;
        const MeanVarDataType* estimatedVariance_;
        YDataType* p_y_;
-        index_t n_, h_, w_, c_;
+        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
+        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
+        AccDataType epsilon_;
    };
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
-            auto thread_reduce_func = [&](auto iC) {
-                index_t offset_C = iC;
-                AccDataType mean = arg.estimatedMean_[offset_C];
-                AccDataType variance = arg.estimatedVariance_[offset_C];
+            using ck::host_common::get_offset_from_index;
+            auto thread_reduce_func = [&](auto invariant_index) {
+                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.x_invariant_strides_, invariant_index);
+                size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.y_invariant_strides_, invariant_index);
+                size_t mean_variance_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_, invariant_index);
+                AccDataType mean = arg.estimatedMean_[mean_variance_offset];
+                AccDataType variance = arg.estimatedVariance_[mean_variance_offset];
+                // inv-variance defined as 1/sqrt(epsilon+variance)
                AccDataType invVariance =
-                    type_convert<AccDataType>(1.0f) /
-                    std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
+                    type_convert<AccDataType>(1.0f) / std::sqrt(arg.epsilon_ + variance);
+                size_t scale_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
+                size_t bias_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
+                AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
+                AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
-                // Normalization
-                for(index_t iN = 0; iN < arg.n_; iN++)
+                // normalization
+                for(const auto& reduce_index : arg.reduce_index_set_)
                {
-                    index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
-                    for(index_t iH = 0; iH < arg.h_; iH++)
-                    {
-                        index_t offset_H = iH * arg.w_ * arg.c_;
-                        for(index_t iW = 0; iW < arg.w_; iW++)
-                        {
-                            index_t offset_W = iW * arg.c_;
-                            auto offset = offset_N + offset_H + offset_W + offset_C;
-                            AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
-                            AccDataType norm_x =
-                                arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
-                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
-                        };
-                    }
+                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.x_reduce_strides_, reduce_index);
+                    size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.y_reduce_strides_, reduce_index);
+                    auto x_offset = x_invariant_offset + x_reduce_offset;
+                    auto y_offset = y_invariant_offset + y_reduce_offset;
+                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
+                    AccDataType norm_x = (x - mean) * invVariance;
+                    AccDataType y = scale * norm_x + bias;
+                    arg.y_elementwise_op_(y, y);
+                    arg.p_y_[y_offset] = type_convert<YDataType>(y);
                };
            };
            std::size_t num_thread = std::thread::hardware_concurrency();
-            std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
+            std::size_t work_per_thread =
+                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
            std::vector<joinable_thread> threads(num_thread);
            for(std::size_t it = 0; it < num_thread; ++it)
            {
-                std::size_t ic_begin = it * work_per_thread;
-                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
+                std::size_t i_begin = it * work_per_thread;
+                std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                             arg.invariant_index_set_.size());
                auto f = [=] {
-                    for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
+                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
-                        thread_reduce_func(ic);
+                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };
@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
    };
    std::unique_ptr<device::BaseArgument>
-    MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
-                        const std::array<index_t, 4> xStrides,
-                        const std::array<index_t, 4> yStrides,
-                        const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                        const std::array<index_t, 1> bnScaleStrides,
-                        const std::array<index_t, 1> bnBiasStrides,
-                        const std::array<index_t, 1> bnMeanVarStrides,
+    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
+                        const std::array<index_t, Rank> xStrides,
+                        const std::array<index_t, Rank> yStrides,
+                        const std::array<int, NumBatchNormReduceDim> reduceDims,
+                        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
+                        const std::array<index_t, NumInvariantDim> bnScaleStrides,
+                        const std::array<index_t, NumInvariantDim> bnBiasStrides,
+                        const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* bnScale,
                        const void* bnBias,
                        double epsilon,
+                        const YElementwiseOp y_elementwise_op,
                        const void* estimatedMean,
                        const void* estimatedVariance,
                        void* p_y) override
@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
+                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                                          static_cast<const ScaleDataType*>(bnScale),
                                          static_cast<const BiasDataType*>(bnBias),
                                          epsilon,
+                                          y_elementwise_op,
                                          static_cast<const MeanVarDataType*>(estimatedMean),
                                          static_cast<const MeanVarDataType*>(estimatedVariance),
                                          static_cast<YDataType*>(p_y));
@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
        auto str = std::stringstream();
        // clang-format off
-        str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
+        str << "Reference_BatchNorm_Infer<" << std::endl;
        // clang-format on
        return str.str();
...
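The inference path above applies the affine normalization with the supplied running statistics instead of batch statistics:

\[
y = \mathrm{op}\!\left(\gamma\,\frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \varepsilon}} + \beta\right),
\]

where \(\hat{\mu}\), \(\hat{\sigma}^2\) come from estimatedMean_/estimatedVariance_, \(\gamma\)/\(\beta\) from bnScale_/bnBias_, and op is the newly added YElementwiseOp.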
@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
        else if constexpr(NDimSpatial == 2)
        {
            auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
+                std::size_t N = arg.output_.GetLengths()[1];
+                std::size_t Ho = arg.output_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[4];
                float v_acc = 0;
-                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
+                for(std::size_t n = 0; n < N; ++n)
                {
-                    for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
+                    for(std::size_t ho = 0; ho < Ho; ++ho)
                    {
                        auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                  static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
+                        for(std::size_t wo = 0; wo < Wo; ++wo)
                        {
                            auto wi =
                                static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
...
@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
        size_t M = acc.mDesc.GetLengths()[0];
        size_t N = acc.mDesc.GetLengths()[1];
-        Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-        Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+        Tensor<ComputeDataType> avg_acc_sq({M});
+        Tensor<ComputeDataType> avg_acc({M});
        Tensor<ComputeDataType> acc_layernorm(acc);
        // reduce N dim
...
@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
        {
            for(int n = 0; n < N; ++n)
            {
                auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
                auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
                y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
+                arg.acc_elementwise_op_(y_val, y_val);
                arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
            }
        }
...
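With the added acc_elementwise_op_ hook, each output element of the layernorm reference is

\[
y_{m,n} = \mathrm{op}\!\left(\gamma_n\,\frac{x_{m,n} - \mu_m}{\sqrt{\sigma_m^2 + \varepsilon}} + \beta_n\right).
\]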
...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator
{ {
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
} }
// max and sum reduction with final reduced values of dim=0 is a scalar so give it
// appropriate lengths of {1}
if(arg.sm_scalar_dims_.size() == 0)
{
scalar_lengths.push_back(1);
}
Tensor<AccDataType> reduce_max(scalar_lengths); Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue( reduce_max.GenerateTensorValue(
...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> reduce_sum(scalar_lengths); Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}); reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
// when final reduced values is of dim=0, the index will be transformed into empty
// std::vector which is actually a valid input for Tensor::operator(std::vector) and
// internally accesses 0'th element
auto to_sm_scalar_idx = [&](auto idx) { auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx; std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_) for(index_t dim : arg.sm_scalar_dims_)
...@@ -77,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -77,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
}; };
arg.in_.ForEach([&](auto& self, auto idx) { arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)), reduce_max(to_sm_scalar_idx(idx)) = std::max(
static_cast<AccDataType>(self(idx))); reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
}); });
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") << // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
...@@ -87,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -87,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> in_stable(arg.in_.mDesc); Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) { in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x)) // numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) - self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx))); reduce_max(to_sm_scalar_idx(idx)));
}); });
...@@ -102,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -102,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
// std::endl; // std::endl;
arg.out_.ForEach([&](auto& self, auto idx) { arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + AccDataType temp_result =
arg.beta_ * self(idx); arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
self(idx) = ck::type_convert<OutDataType>(temp_result);
}); });
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
......
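For reference, the computation in this hunk — a numerically stable softmax scaled as out = alpha * softmax(in) + beta * out — can be sketched outside the CK tensor machinery as plain row-wise C++. The function below is illustrative only; names and the 2-D layout are assumptions, not library API.

// Minimal sketch: out = alpha * softmax(in) + beta * out, computed row by row.
// Mirrors the reference above: subtract the row max before exp() for stability.
#include <algorithm>
#include <cmath>
#include <vector>

void softmax_rows(const std::vector<float>& in, std::vector<float>& out,
                  std::size_t rows, std::size_t cols, float alpha, float beta)
{
    for(std::size_t m = 0; m < rows; ++m)
    {
        const float* x = in.data() + m * cols;
        float* y       = out.data() + m * cols;

        // reduce_max over the row
        float row_max = *std::max_element(x, x + cols);

        // numerator = exp(x - max(x)); reduce_sum over the row
        std::vector<float> stable(cols);
        float row_sum = 0.f;
        for(std::size_t n = 0; n < cols; ++n)
        {
            stable[n] = std::exp(x[n] - row_max);
            row_sum += stable[n];
        }

        // out = alpha * numerator / denominator + beta * out
        for(std::size_t n = 0; n < cols; ++n)
            y[n] = alpha * stable[n] / row_sum + beta * y[n];
    }
}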
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,7 +26,9 @@ using Empty_Tuple = ck::Tuple<>; ...@@ -26,7 +26,9 @@ using Empty_Tuple = ck::Tuple<>;
using F16_Tuple = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -75,13 +77,35 @@ using NWGK = ck::tensor_layout::convolution::NWGK; ...@@ -75,13 +77,35 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK; using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
//
using GK = ck::tensor_layout::convolution::G_K;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename DeviceOp> template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
} // namespace instance } // namespace instance
......
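The extra typename Tag = void parameter added to DeviceOperationInstanceFactory follows the usual tag-dispatch idiom: existing specializations keep matching through the default argument, while new specializations can key on a tag without touching them. A minimal, self-contained illustration of the idiom (the types below are placeholders, not CK's):

#include <iostream>

struct OpA {};       // stands in for a device-op interface type
struct CustomTag {}; // stands in for a factory tag

// primary template with a defaulted tag, mirroring the declaration above
template <typename DeviceOp, typename Tag = void>
struct Factory;

// existing-style specialization: Factory<OpA> still matches unchanged
template <>
struct Factory<OpA>
{
    static const char* Name() { return "default OpA instances"; }
};

// new specialization selected only when the tag is supplied
template <>
struct Factory<OpA, CustomTag>
{
    static const char* Name() { return "tagged OpA instances"; }
};

int main()
{
    std::cout << Factory<OpA>::Name() << '\n';            // default path
    std::cout << Factory<OpA, CustomTag>::Name() << '\n'; // tag-selected path
}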
...@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g ...@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType, template <typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
...@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory< ...@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs; return op_ptrs;
} }
}; };
......
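A hedged usage sketch for the new BF16 path: the template arguments mirror the MaskOutUpperTriangle specialization declared above, but the include path is an assumption and only the instance count is inspected, since no other accessor is shown in this header.

#include <iostream>
// assumed header path for the factory declared above
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // factory specialization matching the new BF16 / MaskOutUpperTriangle declarations
    using DeviceOp =
        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
                                            BF16, BF16, BF16, BF16,
                                            ck::Tuple<>, ck::Tuple<>,
                                            PassThrough, PassThrough, Scale,
                                            PassThrough, PassThrough,
                                            MaskingSpecialization::MaskOutUpperTriangle>;

    auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::cout << "bf16 masked gemm+softmax+gemm instances: " << op_ptrs.size() << std::endl;
}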
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
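Note that every branch in these GetInstances() bodies is guarded by if constexpr, so a factory instantiated with an unsupported data-type/rank combination returns an empty vector rather than failing to compile; callers are expected to check for that. A standalone illustration of the pattern (placeholder types, not the CK ones):

#include <cassert>
#include <memory>
#include <type_traits>
#include <vector>

struct Op { virtual ~Op() = default; };

// pattern used by the factories above: only matching configs populate the list
template <typename DataType, int Rank, int NumReduceDim>
std::vector<std::unique_ptr<Op>> get_instances()
{
    std::vector<std::unique_ptr<Op>> op_ptrs;
    if constexpr(std::is_same_v<DataType, float> && Rank == 4 && NumReduceDim == 3)
    {
        op_ptrs.push_back(std::make_unique<Op>()); // stands in for the add_device_* calls
    }
    return op_ptrs; // empty for every other combination
}

int main()
{
    assert(!get_instances<float, 4, 3>().empty()); // supported config
    assert(get_instances<double, 2, 1>().empty()); // unsupported -> empty, not a compile error
}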
...@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward data // conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
...@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_elementwise_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceElementwiseNormalization<ck::Tuple<F16, F16>,
F16,
F16,
F32,
F16,
element_wise::Add,
PassThrough,
2,
1>>>&);
template <typename InDataTypeTuple,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwiseNormalization<
InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceElementwiseNormalization<InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<GammaDataType, F16> && is_same_v<BetaDataType, F16> &&
is_same_v<YDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
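The DeviceElementwiseNormalization instances registered here fuse an elementwise Add of the two F16 inputs with a layer-norm style normalization computed in F32, followed by a PassThrough output op. A plain C++ sketch of that computation for the rank-2, reduce-1-dim case; this is illustrative only, and the epsilon handling and accumulation order may differ from the device kernels.

#include <cmath>
#include <vector>

// y[m][n] = gamma[n] * (x0[m][n] + x1[m][n] - mean_m) / sqrt(var_m + eps) + beta[n]
void add_layernorm_rows(const std::vector<float>& x0, const std::vector<float>& x1,
                        const std::vector<float>& gamma, const std::vector<float>& beta,
                        std::vector<float>& y, std::size_t M, std::size_t N, float eps)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // elementwise op: X = Add(x0, x1), accumulated in fp32
        std::vector<float> x(N);
        float mean = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            x[n] = x0[m * N + n] + x1[m * N + n];
            mean += x[n];
        }
        mean /= N;

        float var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            var += (x[n] - mean) * (x[n] - mean);
        var /= N;

        // normalization: scale by gamma, shift by beta
        for(std::size_t n = 0; n < N; ++n)
            y[m * N + n] = gamma[n] * (x[n] - mean) / std::sqrt(var + eps) + beta[n];
    }
}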
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
// GEMM + Add + FastGelu
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
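The AddFastGelu epilogue registered here applies E = FastGelu(C + D0) elementwise to the GEMM result. Below is a minimal host-side sketch of that epilogue using the common tanh-based GELU approximation; CK's exact FastGelu polynomial may differ, so treat the formula as illustrative rather than the library definition.

#include <cmath>
#include <cstddef>

// common tanh approximation of GELU; used here only to illustrate the epilogue shape
inline float fast_gelu(float x)
{
    const float c = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.f + std::tanh(c * (x + 0.044715f * x * x * x)));
}

// E[m][n] = FastGelu(C[m][n] + D0[m][n]), applied elementwise after the GEMM
void add_fastgelu_epilogue(const float* c, const float* d0, float* e,
                           std::size_t M, std::size_t N)
{
    for(std::size_t i = 0; i < M * N; ++i)
        e[i] = fast_gelu(c[i] + d0[i]);
}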