Unverified commit f5de8b57, authored by Chao Liu and committed by GitHub

Merge branch 'develop' into modified_grouped_gemm_addressing_method

parents e83c7061 fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
                        std::array<void*, NumOutputTensor> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
using DeviceElementwisePtr =
    std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
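For orientation, here is a minimal host-side sketch of how this interface is typically driven. The factory function and the device pointers p_x_dev / p_y_dev are illustrative assumptions, not part of this header; only the two virtual methods above are actually defined here.

// Hedged usage sketch for a DeviceElementwise<1, 1, 2, PassThrough> instance.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

ck::tensor_operation::device::DeviceElementwisePtr<1, 1, 2, PassThrough> op =
    MakeSomeDeviceElementwiseInstance(); // hypothetical factory returning a concrete instance

std::array<const void*, 1> inputs  = {p_x_dev}; // device memory prepared by the caller
std::array<void*, 1>       outputs = {p_y_dev};

auto argument = op->MakeArgumentPointer(inputs,
                                        outputs,
                                        {1024, 256}, // lengths, high to low dimension (NDim = 2)
                                        {{256, 1}},  // one stride vector per input tensor
                                        {{256, 1}},  // one stride vector per output tensor
                                        PassThrough{});
auto invoker  = op->MakeInvokerPointer();
// invoker->Run(argument.get(), ...); // launch via BaseInvoker; the exact Run signature lives in device_base.hpp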
@@ -2,52 +2,55 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

 #include <iostream>
+#include <array>

 #include "device_base.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

+struct DEGridDesc_M0_M1_M2_N0_N1
+{
+    ck::index_t M0_, M1_, M2_, N0_, N1_;
+    ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
+};
+
+// input : A[M, K], B[K, N],
+// input : D[M, N], ...
+// output : E[M, N]
+// C = a_op(A) * b_op(B)
+// E = cde_op(C, D)
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-struct DeviceBatchedGemmReduce : public BaseOperator
+          typename CDEElementwiseOperation>
+struct DeviceGemmBiasCPermute : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        void* p_c,
-                        void* p_dxs,
+                        const void* p_d,
+                        void* p_e,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        ck::index_t StrideC,
+                        DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
+                        DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        ck::index_t Batch) = 0;
+                        CDEElementwiseOperation cde_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceBatchedGemmReducePtr =
-    std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
-                                            BElementwiseOperation,
-                                            CElementwiseOperation,
-                                            DxsInElementwiseOperation,
-                                            DxsReduceAccElementwiseOperation>>;
+          typename CElementwiseOperation>
+using DeviceGemmBiasCPermutePtr = std::unique_ptr<
+    DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
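Reading the new interface's comments with a concrete elementwise operation makes the intent clearer; Add is just one possible choice of CDEElementwiseOperation (it appears further down in this same commit), and the split of M and N below is an interpretation of the descriptor's field names rather than something this header states.

// With CDEElementwiseOperation = element_wise::Add (one possible choice):
//   C[m, n] = a_op(A[m, :]) dot b_op(B[:, n])                 // the GEMM product
//   E[m, n] = cde_op(C[m, n], D[m, n]) = C[m, n] + D[m, n]    // bias / residual add
// The DEGridDesc_M0_M1_M2_N0_N1 arguments describe D and E as 5-D views with explicit
// per-dimension strides (presumably M split into M0*M1*M2 and N into N0*N1), which is what
// lets the output be written in a permuted layout.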
@@ -9,91 +9,34 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
+template <ck::index_t NumDTensor, ck::index_t NumReduce>
 struct DeviceGemmReduce : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, NumDTensor> p_ds,
                         void* p_c,
-                        void* p_dxs,
+                        std::array<void*, NumReduce> p_reduces,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
                         ck::index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, NumDTensor> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_ops,
+                        std::array<void*, NumReduce> reduce_out_element_ops,
                         ck::index_t BatchCount = 1) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
-                                                             BElementwiseOperation,
-                                                             CElementwiseOperation,
-                                                             DxsInElementwiseOperation,
-                                                             DxsReduceAccElementwiseOperation>>;
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename C1ElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-struct DeviceGemmBiasAddReduce : public BaseOperator
-{
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        const void* p_c0,
-                        const void* p_c1,
-                        void* p_dxs,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB,
-                        ck::index_t StrideC,
-                        ck::index_t StrideC1,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        C1ElementwiseOperation c1_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        ck::index_t BatchCount = 1) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-};
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          typename C1ElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation>
-using DeviceGemmBiasAddReducePtr =
-    std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
-                                            BElementwiseOperation,
-                                            CElementwiseOperation,
-                                            C1ElementwiseOperation,
-                                            DxsInElementwiseOperation,
-                                            DxsReduceAccElementwiseOperation>>;
+template <ck::index_t NumDTensor, ck::index_t NumReduce>
+using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;

 } // namespace device
 } // namespace tensor_operation
...
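The reworked DeviceGemmReduce interface erases the element-wise functor types, so callers pass addresses of functor objects through void* arrays. A hedged sketch of filling those arrays follows; the ordering of gemm_element_ops as (A-op, B-op, C-op) and the use of PassThrough for every slot are assumptions made for illustration, not something this header specifies.

// Sketch only: NumDTensor = 0, NumReduce = 1, all element-ops assumed to be PassThrough.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

PassThrough a_op, b_op, c_op;
std::array<void*, 3> gemm_element_ops = {&a_op, &b_op, &c_op}; // assumed order: A-op, B-op, C-op

PassThrough reduce_in_op, reduce_out_op;
std::array<void*, 0> d_element_ops          = {};
std::array<void*, 1> reduce_in_element_ops  = {&reduce_in_op};
std::array<void*, 1> reduce_out_element_ops = {&reduce_out_op};
// These arrays are then handed to MakeArgumentPointer(...) alongside the usual M/N/K/strides.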
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct DeviceNormalization : public BaseOperator
{
    // inLengths: input tensor extent(s) from high to low dimension
    // inStrides: input tensor stride(s) from high to low dimension
    // reduceDims: the dimension(s) the normalization operation is applied to
    // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
    // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
    // in_dev: typeless const pointer in device memory storing the input tensor
    // out_dev: typeless pointer in device memory storing the output tensor
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                              const std::vector<index_t> inStrides,
                                                              const std::vector<int> reduceDims,
                                                              const void* alpha,
                                                              const void* beta,
                                                              const void* in_dev,
                                                              void* out_dev) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual index_t GetRank() const = 0;

    virtual index_t GetNumReduceDim() const = 0;
};

using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
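A hedged usage sketch of this type-erased interface, following the parameter comments above. How the concrete DeviceNormalizationPtr (normalization_ptr) and the device buffers p_in_dev / p_out_dev are obtained is outside this header, and AccDataType = float is assumed for the scaling values.

float alpha = 1.0f; // assumption: this instance's AccDataType is float
float beta  = 0.0f;

std::vector<ck::index_t> inLengths  = {8, 128, 1024};        // extents, high to low dimension
std::vector<ck::index_t> inStrides  = {128 * 1024, 1024, 1}; // packed row-major strides
std::vector<int>         reduceDims = {2};                   // normalize along the innermost dimension

auto argument = normalization_ptr->MakeArgumentPointer(
    inLengths, inStrides, reduceDims, &alpha, &beta, p_in_dev, p_out_dev);
auto invoker  = normalization_ptr->MakeInvokerPointer();
// invoker->Run(argument.get(), ...); // launch; the exact Run signature comes from BaseInvoker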
@@ -9,6 +9,7 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
@@ -33,8 +34,15 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceSoftmax : public BaseOperator
+struct DeviceSoftmax : public DeviceNormalization
 {
+    static constexpr index_t kRank         = Rank;
+    static constexpr index_t kNumReduceDim = NumReduceDim;
+
+    virtual index_t GetRank() const override { return kRank; }
+
+    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

     // Used for freeloading of some handy functions from DeviceReduceMultiBlock
@@ -61,18 +69,33 @@ struct DeviceSoftmax : public BaseOperator
     using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));

-    using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
-                                                    OutDataType,
-                                                    AccDataType,
-                                                    GridDesc_M_K,
-                                                    BlockSize,
-                                                    MThreadClusterSize,
-                                                    KThreadClusterSize,
-                                                    MThreadSliceSize,
-                                                    KThreadSliceSize,
-                                                    InSrcVectorDim,
-                                                    InSrcVectorSize,
-                                                    OutDstVectorSize>;
+    using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
+                                                            OutDataType,
+                                                            AccDataType,
+                                                            GridDesc_M_K,
+                                                            BlockSize,
+                                                            MThreadClusterSize,
+                                                            KThreadClusterSize,
+                                                            MThreadSliceSize,
+                                                            KThreadSliceSize,
+                                                            InSrcVectorDim,
+                                                            InSrcVectorSize,
+                                                            OutDstVectorSize,
+                                                            false>;
+
+    using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
+                                                              OutDataType,
+                                                              AccDataType,
+                                                              GridDesc_M_K,
+                                                              BlockSize,
+                                                              MThreadClusterSize,
+                                                              KThreadClusterSize,
+                                                              MThreadSliceSize,
+                                                              KThreadSliceSize,
+                                                              InSrcVectorDim,
+                                                              InSrcVectorSize,
+                                                              OutDstVectorSize,
+                                                              true>;

     struct Argument : public Reduction::Argument
     {
@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
         const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
             arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

-        const auto kernel_main =
-            kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
+        bool sweep_once =
+            in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+        const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
+                                                             InDataType,
+                                                             OutDataType,
+                                                             AccDataType,
+                                                             GridDesc_M_K>
+                                            : kernel_softmax<GridwiseSoftmaxGeneric,
+                                                             InDataType,
+                                                             OutDataType,
+                                                             AccDataType,
+                                                             GridDesc_M_K>;

         float avg_time = 0;
@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
         return true;
     };

+    // inLengths: input tensor extent(s) from high to low dimension
+    // inStrides: input tensor stride(s) from high to low dimension
+    // reduceDims: the dimension(s) the softmax normalization operate on
+    // alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
+    // beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
+    // in_dev: typeless const pointer in device memory storing the input tensor
+    // out_dev: typeless pointer in device memory storing the output tensor
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                       const std::vector<index_t> inStrides,
                                                       const std::vector<int> reduceDims,
-                                                      AccDataType alpha,
-                                                      AccDataType beta,
+                                                      const void* alpha,
+                                                      const void* beta,
                                                       const void* in_dev,
-                                                      void* out_dev)
+                                                      void* out_dev) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           reduceDims,
-                                          alpha,
-                                          beta,
+                                          *static_cast<const AccDataType*>(alpha),
+                                          *static_cast<const AccDataType*>(beta),
                                           static_cast<const InDataType*>(in_dev),
                                           static_cast<OutDataType*>(out_dev));
     };

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    };

     std::string GetTypeString() const override
     {
...
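To make the new sweep-once dispatch concrete, here is a small worked example with assumed tile parameters (these are illustrative values, not the library's tuned defaults).

// Assumed tile parameters for illustration only.
constexpr ck::index_t KThreadClusterSize = 64;
constexpr ck::index_t KThreadSliceSize   = 8;

// One workgroup covers KThreadClusterSize * KThreadSliceSize = 512 elements of the reduced
// dimension per pass, so a softmax row of length K <= 512 selects the sweep-once kernel,
// while longer rows fall back to the generic, multi-pass GridwiseSoftmaxGeneric path.
constexpr bool sweep_once_for_k_384  = (384  <= KThreadClusterSize * KThreadSliceSize); // true
constexpr bool sweep_once_for_k_2048 = (2048 <= KThreadClusterSize * KThreadSliceSize); // false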
@@ -11,8 +11,8 @@ namespace element_wise {
 struct Add
 {
-    template <typename T>
-    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

     template <>
     __host__ __device__ constexpr void
@@ -28,6 +28,13 @@ struct Add
         y = x0 + x1;
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
+    {
+        y = type_convert<half_t>(x0) + x1;
+    };
+
     // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
...
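A small usage sketch of the widened Add signature introduced above, showing the mixed float/half_t overload being picked. It assumes ck::half_t and ck::type_convert are in scope, as they are in the surrounding header.

ck::half_t y;
float      x0 = 1.5f;                                 // e.g. an accumulator value kept in float
ck::half_t x1 = ck::type_convert<ck::half_t>(0.25f);

ck::tensor_operation::element_wise::Add{}(y, x0, x1); // resolves to the (half_t&, float, half_t) overload
// y now holds type_convert<half_t>(x0) + x1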
@@ -236,9 +236,14 @@ template <typename SrcData,
           index_t SrcScalarPerVector,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
+          bool InvalidElementAsNaN = false,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
+    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+                      (!InvalidElementAsNaN),
+                  "Filling invalid element as NaN is only for floating point types");
+
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
                     dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                              i * src_scalar_step_in_vector);

-                dst_buf(Number<dst_offset>{}) =
-                    type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                if constexpr(InvalidElementAsNaN)
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        is_src_valid
+                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
+                            : NumericLimits<DstData>::QuietNaN();
+                }
+                else
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                }
             });

             if constexpr(idx_1d.value != num_access - 1)
...
@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
 template <typename T>
 __device__ T exp(T x);

+// TODO: add f16 support using v_exp_f16
+
 template <>
 __device__ float exp<float>(float x)
 {
...
@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
 {
     __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(!isnan(currVal))
+        if(!ck::math::isnan(currVal))
         {
             ReduceOperation{}(accuVal, currVal);
         }
...
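The two changes above work as a pair: out-of-range source elements can be written as quiet NaN by the threadwise transfer, and the accumulator then skips them. A minimal standalone illustration of that pattern in plain C++ (independent of the CK types) is:

#include <cmath>
#include <limits>
#include <vector>

// Reduce only the valid (non-NaN) entries; invalid lanes were tagged with quiet NaN upstream.
float max_ignoring_nan(const std::vector<float>& padded_row)
{
    float acc = -std::numeric_limits<float>::infinity();
    for(float v : padded_row)
    {
        if(!std::isnan(v)) // mirrors AccumulateWithNanIgnore: skip tagged-invalid elements
        {
            acc = v > acc ? v : acc;
        }
    }
    return acc;
}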
@@ -222,6 +222,12 @@ struct Tensor
     Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

+    Tensor& operator=(const Tensor& other)
+    {
+        mDesc = other.mDesc;
+        mData = other.mData;
+        return *this; // copy assignment returns the assigned-to tensor
+    }
+
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
     {
...
@@ -26,12 +26,11 @@ struct ReferenceSoftmax : public device::BaseOperator
                  Tensor<OutDataType>& out,
                  AccDataType alpha,
                  AccDataType beta,
-                 const index_t rank,
                  const std::vector<index_t> sm_reduce_dims)
             : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
         {
             // std::cout << "debug: scalar dims: ";
-            for(int i = 0; i < rank; i++)
+            for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
             {
                 if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
                    sm_reduce_dims.end())
@@ -47,7 +46,6 @@ struct ReferenceSoftmax : public device::BaseOperator
         Tensor<OutDataType>& out_;
         AccDataType alpha_;
         AccDataType beta_;
-        index_t rank_;
         std::vector<index_t> sm_reduce_dims_;
         std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
     };
@@ -136,10 +134,9 @@ struct ReferenceSoftmax : public device::BaseOperator
                              Tensor<OutDataType>& out,
                              AccDataType alpha,
                              AccDataType beta,
-                             const index_t rank,
                              const std::vector<index_t> sm_reduce_dims)
     {
-        return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
+        return Argument{in, out, alpha, beta, sm_reduce_dims};
     }

     static auto MakeInvoker() { return Invoker{}; }
...
@@ -4,6 +4,7 @@
 #pragma once

 #include <vector>
+#include "ck/utility/functional2.hpp"

 namespace ck {
 namespace tensor_operation {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

using Normalize = ck::tensor_operation::element_wise::Normalize;

using DeviceNormalizeFromMeanMeanSquarePtr =
    ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>;

void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);

template <typename InputType,
          typename MeanType,
          typename MeanSquareType,
          typename GammaDataType,
          typename BetaDataType,
          typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances()
{
    std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;

    if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
                 is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
                 is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
    {
        ck::tensor_operation::device::
            add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
    }

    return op_ptrs;
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
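A hedged sketch of consuming this instance factory: enumerate the registered instances for the f16/f32 combination and use the first one. Real client code would typically time every returned instance and keep the fastest; argument construction then follows the DeviceElementwise interface shown earlier in this commit.

auto op_ptrs = ck::tensor_operation::device::get_device_normalize_from_mean_meansquare_instances<
    ck::half_t, float, float, ck::half_t, ck::half_t, ck::half_t>();

if(!op_ptrs.empty())
{
    // Each entry is a DeviceNormalizeFromMeanMeanSquarePtr, i.e. DeviceElementwisePtr<5, 1, 2, Normalize>:
    // 5 inputs (x, mean, mean-square, gamma, beta), 1 output, 2 dimensions.
    auto& op = op_ptrs.front();
    // op->MakeArgumentPointer(...); op->MakeInvokerPointer(); // as in device_elementwise.hpp above
}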