Commit f1b53d78 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into navi3x_mD_batchedGEMM_GroupConvFwd
parents 0c9cdbce 7494c1c6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : H[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// H = layernorm(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename HLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename HElementwiseOperation>
struct DeviceGemmMultipleDLayernorm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator ...@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator ...@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const void* p_x, const void* p_x,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -13,10 +13,16 @@ namespace ck { ...@@ -13,10 +13,16 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation> typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
struct DeviceReduce : public BaseOperator struct DeviceReduce : public BaseOperator
{ {
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
...@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator ...@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const std::array<index_t, NumOutDim> outLengths, const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides, const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev, const void* in_index_dev,
void* out_dev, void* out_dev,
...@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator ...@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation> typename AccElementwiseOperation,
using DeviceReducePtr = std::unique_ptr< bool PropagateNan,
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>; bool OutputIndex>
using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator ...@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension // @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension // @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied // @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling // @param[in] alpha double type value
// value as type AccDataType // @param[in] beta double type value
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] in_dev Typeless const pointer in device memory storing the input // @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor // tensor
// @param out_dev Typeless pointer in device memory storing the output tensor // @param out_dev Typeless pointer in device memory storing the output tensor
...@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator ...@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
const void* alpha, double alpha,
const void* beta, double beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -26,10 +26,10 @@ template <typename InDataTypeTuple, ...@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
index_t NPerThread, index_t NPerThread,
typename InScalarPerVectorSeq, typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq> typename OutScalarPerVectorSeq>
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple, OutDataTypeTuple,
ElementwiseOperation, ElementwiseOperation,
NumDim_m + NumDim_n> NumDim_m + NumDim_n>
{ {
static constexpr index_t NumDim = NumDim_m + NumDim_n; static constexpr index_t NumDim = NumDim_m + NumDim_n;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -25,8 +25,8 @@ template <typename InDataTypeTuple, ...@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
index_t MPerThread, index_t MPerThread,
typename InScalarPerVectorSeq, typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq> typename OutScalarPerVectorSeq>
struct DeviceElementwise struct DeviceElementwiseImpl
: public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim> : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{ {
static constexpr int NumInput = InDataTypeTuple::Size(); static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size(); static constexpr int NumOutput = OutDataTypeTuple::Size();
......
...@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl ...@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
XElementwiseOperation x_elementwise_op, XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
AccDataType epsilon, double epsilon,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y)
: epsilon_(epsilon), : p_gamma_(p_gamma),
p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
x_elementwise_op_(x_elementwise_op), x_elementwise_op_(x_elementwise_op),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
for(int i = 0; i < NumInput; i++) for(int i = 0; i < NumInput; i++)
...@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl ...@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
...@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_etile_map_); arg.block_2_etile_map_);
}; };
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
return launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
......
...@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
...@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
static_for<0, NumReduction, 1>{}([&](auto I) { static_for<0, NumReduction, 1>{}([&](auto I) {
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>; using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
flag = flag =
flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation, flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value; OutDataType>::value;
}); });
return flag; return flag;
...@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths, const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims, const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas, const std::array<double, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas, const std::array<double, NumReduction>& betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers, const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
...@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++) for(size_t i = 0; i < NumReduction; i++)
{ {
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]); alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]); beta_values_(i) = static_cast<AccDataType>(betas[i]);
}; };
in_dev_ = static_cast<const InDataType*>(in_dev); in_dev_ = static_cast<const InDataType*>(in_dev);
...@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths, const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims, const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas, const std::array<double, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas, const std::array<double, NumReduction>& betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers, const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
...@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++) for(size_t i = 0; i < NumReduction; i++)
{ {
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]); alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]); beta_values_(i) = static_cast<AccDataType>(betas[i]);
}; };
in_dev_ = static_cast<const InDataType*>(in_dev); in_dev_ = static_cast<const InDataType*>(in_dev);
...@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op, AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y)
: epsilon_(epsilon), : p_x_(p_x),
p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op) acc_elementwise_op_(acc_elementwise_op)
{ {
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims); xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
...@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const void* p_x, const void* p_x,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment