Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
@@ -30,15 +30,6 @@ namespace broadcast_elementwise_binary {
constexpr size_t kMaxNumDims = 8;
inline void CheckInplace(size_t num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
const void* dst) {
for (int64_t i = 0; i < num_dims; ++i) {
if (src0 == dst) { CHECK_EQ(src0_dims[i], dst_dims[i]); }
if (src1 == dst) { CHECK_EQ(src1_dims[i], dst_dims[i]); }
}
}
inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims,
const int64_t* src1_dims) {
if (num_src0_dims != num_src1_dims) { return false; }
@@ -48,22 +39,36 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
return true;
}
#define BINARY_MATH_OP_SEQ_0 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFmod) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTruncDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorMod) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarExpPowerGrad)
#define BINARY_MATH_OP_SEQ_1 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarBasePowerGrad)
#define BINARY_MATH_OP_SEQ \
BINARY_MATH_OP_SEQ_0 \
BINARY_MATH_OP_SEQ_1
#define BINARY_COMPARISION_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsCloseEqualNan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsClose)
#define BINARY_LOGICAL_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd) \
@@ -87,7 +92,39 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftplusBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftshrinkBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX)
#define BINARY_MATH_BACKWARD_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAbsBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcosBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCosBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog2BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog10BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog1pBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogSigmoidBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalNoNanBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kRsqrtBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanBackwardWithDyX)
} // namespace broadcast_elementwise_binary
} // namespace primitive
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
#define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/include/primitive/fast_integer_math.h"
#include "oneflow/core/ep/common/primitive/util.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_unary {
constexpr size_t kMaxNumDims = 8;
template<typename T, int N>
class IndexToOffsetWithStrideCalculator {
public:
IndexToOffsetWithStrideCalculator() {}
OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides) {
InitStrides(strides, N);
}
template<typename U>
OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides) {
T strides_arr[N];
for (int i = 0; i < N; ++i) { strides_arr[i] = strides[i]; }
InitStrides(strides_arr, N);
}
OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides, int n) {
InitStrides(strides, n);
}
template<typename U>
OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides, int n) {
T strides_arr[N];
for (int i = 0; i < N; ++i) {
if (i < n) { strides_arr[i] = strides[i]; }
}
InitStrides(strides_arr, n);
}
~IndexToOffsetWithStrideCalculator() = default;
OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N - 1; ++i) { offset += index[i] * stride_[i]; }
offset += index[N - 1];
return offset;
}
OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const {
assert(n <= N);
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i < n) { offset += index[i] * stride_[i]; }
}
return offset;
}
OF_DEVICE_FUNC constexpr int Size() const { return N; }
private:
OF_DEVICE_FUNC void InitStrides(const T* strides, const int n) {
for (int i = n; i < N; ++i) { stride_[i] = 1; }
for (int i = n - 1; i >= 0; --i) { stride_[i] = strides[i]; }
}
T stride_[N];
};
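// Example of the calculator above: with strides {4, 1} and n = 2, NdIndexToOffset maps the
// nd-index {1, 2} to 1 * 4 + 2 * 1 = 6, i.e. the offset is the sum of index[i] * stride_[i]
// over the first n dimensions, and dimensions at or beyond n are initialized to stride 1.
// Note that the fixed-N overload adds index[N - 1] directly, so it assumes the innermost
// stride is 1.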
template<typename T, int N>
class OffsetToIndexWithStrideCalculator {
public:
OffsetToIndexWithStrideCalculator() {}
OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims) {
InitFastIntegerMath(dims, N);
}
template<typename U>
OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims) {
T dims_arr[N];
for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }
InitFastIntegerMath(dims_arr, N);
}
OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims, int n) {
InitFastIntegerMath(dims, n);
}
template<typename U>
OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims, int n) {
T dims_arr[N];
for (int i = 0; i < N; ++i) {
if (i < n) { dims_arr[i] = dims[i]; }
}
InitFastIntegerMath(dims_arr, n);
}
~OffsetToIndexWithStrideCalculator() = default;
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N - 1; ++i) {
const T idx = math_helper_[i].divides(remaining);
index[i] = idx;
remaining = remaining - math_helper_[i].mul(idx);
}
index[N - 1] = remaining;
}
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {
assert(n <= N);
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i == n - 1) { break; }
if (i < n - 1) {
const T idx = math_helper_[i].divides(remaining);
index[i] = idx;
remaining = remaining - math_helper_[i].mul(idx);
}
}
index[n - 1] = remaining;
}
OF_DEVICE_FUNC constexpr int Size() const { return N; }
private:
OF_DEVICE_FUNC void InitFastIntegerMath(const T* dims, const int n) {
T stride_arr[N];
for (int i = n - 1; i < N; ++i) {
stride_arr[i] = 1;
math_helper_[i] = FastIntegerMath<T>(1);
}
for (int i = n - 2; i >= 0; --i) {
stride_arr[i] = dims[i + 1] * stride_arr[i + 1];
math_helper_[i] = FastIntegerMath<T>(stride_arr[i]);
}
}
FastIntegerMath<T> math_helper_[N];
};
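// Example of the calculator above: constructed from dims {3, 4}, InitFastIntegerMath
// precomputes the contiguous strides {4, 1}, so OffsetToNdIndex(6, index, 2) recovers
// index {1, 2} by repeated division (6 / 4 = 1, remainder 2); FastIntegerMath wraps the
// division and multiplication by each fixed stride.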
#define UNARY_BROADCAST_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity)
} // namespace broadcast_elementwise_unary
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
@@ -206,6 +206,7 @@ void SimplifyCopyNd(size_t num_dims, const int64_t* dst_dims, const int64_t* dst
void SimplifyThenLaunch(Stream* stream, DataType data_type, size_t num_dims, void* dst,
const int64_t* dst_dims, const int64_t* dst_pos, const void* src,
const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) {
CHECK_GT(num_dims, 0) << "num_dims must be greater than 0";
CHECK_LE(num_dims, kMaxNumDims);
size_t simplified_num_dims = 0;
int64_t simplified_dst_dims[kMaxNumDims];
...
@@ -25,29 +25,72 @@ namespace primitive {
#define UNARY_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRelu)
#define UNARY_FLOATING_MATH_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kElu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSwish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLeakyRelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kMish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSilu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftPlus) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kThreshold) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsin) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsinh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCeil) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExpm1) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFloor) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLgamma) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog2) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog10) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog1p) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocal) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocalNoNan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRint) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRound) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRsqrt) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSin) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSinh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquare) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrunc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)
#define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs)
#define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot)
#define UNARY_UTILS_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsInf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsFinite)
} // namespace primitive
} // namespace ep
...
@@ -60,12 +60,10 @@ REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MatmulFactory, MatmulFactoryImpl<De
#ifdef WITH_CUDA
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MatmulFactory, MatmulFactoryImpl<DeviceType::kCUDA>);
#endif // WITH_CUDA
#ifdef WITH_ROCM
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MatmulFactory, MatmulFactoryImpl<DeviceType::kCUDA>);
#endif // WITH_ROCM
} // namespace
} // namespace primitive
...
@@ -28,9 +28,16 @@ namespace primitive {
template<DeviceType device, UnaryOp unary_op, typename Dst, typename Src>
struct UnaryFunctor;
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kIdentity, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kElu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(
@@ -41,7 +48,7 @@ struct UnaryFunctor<device, UnaryOp::kElu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kCelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)
: alpha(attr0.Value<double>()), inv_alpha(1.0f / attr0.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
@@ -54,7 +61,7 @@ struct UnaryFunctor<device, UnaryOp::kCelu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kHardSwish, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
if (src <= static_cast<Src>(-3)) {
@@ -69,7 +76,7 @@ struct UnaryFunctor<device, UnaryOp::kHardSwish, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kHardSigmoid, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
if (src <= static_cast<Src>(-3)) {
@@ -84,7 +91,7 @@ struct UnaryFunctor<device, UnaryOp::kHardSigmoid, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kHardShrink, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : lambd(attr0.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return (src <= lambd && src >= -lambd) ? static_cast<Dst>(0) : static_cast<Dst>(src);
@@ -95,7 +102,7 @@ struct UnaryFunctor<device, UnaryOp::kHardShrink, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kHardTanh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)
: min_val(attr0.Value<double>()), max_val(attr1.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
@@ -114,7 +121,7 @@ struct UnaryFunctor<device, UnaryOp::kHardTanh, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLeakyRelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<float>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>((src > static_cast<Src>(0.0)) ? src : alpha * src);
@@ -124,7 +131,7 @@ struct UnaryFunctor<device, UnaryOp::kLeakyRelu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kMish, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
Src soft_plus_val = log(static_cast<Src>(1) + exp(src));
@@ -137,7 +144,7 @@ struct UnaryFunctor<device, UnaryOp::kMish, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
const Src zero_val = static_cast<Src>(0.0);
@@ -151,7 +158,7 @@ struct UnaryFunctor<device, UnaryOp::kRelu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSilu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(src / (static_cast<Src>(1) + exp(-src)));
@@ -160,7 +167,7 @@ struct UnaryFunctor<device, UnaryOp::kSilu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>((src > static_cast<Src>(0.0))
@@ -173,7 +180,7 @@ struct UnaryFunctor<device, UnaryOp::kSelu, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSoftSign, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(src / (static_cast<Src>(1) + abs(src)));
@@ -182,7 +189,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftSign, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSoftPlus, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)
: beta(attr0.Value<double>()), threshold(attr1.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
@@ -196,7 +203,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftPlus, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSoftShrink, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
if (src <= alpha && src >= -alpha) {
@@ -212,7 +219,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftShrink, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kThreshold, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)
: threshold(attr0.Value<double>()), value(attr1.Value<double>()) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
@@ -224,7 +231,7 @@ struct UnaryFunctor<device, UnaryOp::kThreshold, Dst, Src> {
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLogicalNot, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(!src); }
};
@@ -243,6 +250,300 @@ struct UnaryFunctor<device, UnaryOp::kIsNan, bool, Src> {
OF_DEVICE_FUNC bool operator()(Src src) const { return false; }
};
template<DeviceType device, typename Src>
struct UnaryFunctor<device, UnaryOp::kIsFinite, bool, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(Src src) const { return true; }
};
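// Like the kIsInf/kIsNan fallbacks above, this generic overload covers non-floating types,
// which are always finite; floating-point inputs presumably go through device-specific
// specializations.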
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kTrunc, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1);
OF_DEVICE_FUNC Dst operator()(Src src) const;
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAbs, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(abs(src)); }
};
template<DeviceType device>
struct UnaryFunctor<device, UnaryOp::kAbs, uint8_t, uint8_t> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC uint8_t operator()(uint8_t src) const { return src; }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kExp, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(exp(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAcos, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(acos(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAcosh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(acosh(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAsin, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(asin(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAsinh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(asinh(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAtan, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(atan(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kAtanh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(atanh(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kCeil, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(ceil(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kCos, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(cos(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kCosh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(cosh(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kErf, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(erf(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kErfc, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(erfc(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kExpm1, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(expm1(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kFloor, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(floor(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLgamma, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(lgamma(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLog, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLog2, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log2(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLog10, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log10(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLog1p, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log1p(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kLogSigmoid, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(-log(static_cast<Src>(1.0) + exp(-src)));
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNegative, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(-src); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kReciprocal, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(static_cast<Src>(1.0) / src);
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kReciprocalNoNan, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
if (abs(src) <= static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }
return static_cast<Dst>(static_cast<Src>(1.0) / src);
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRint, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(rint(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRound, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(nearbyint(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRsqrt, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(rsqrt(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSigmoid, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(static_cast<Src>(1.0) / (static_cast<Src>(1.0) + exp(-src)));
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSign, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
const Src zero = static_cast<Src>(0.0);
if (src > zero) {
return static_cast<Dst>(1.0);
} else if (src < zero) {
return static_cast<Dst>(-1.0);
} else {
return static_cast<Dst>(0.0);
}
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSin, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(sin(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSinh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(sinh(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSqrt, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(sqrt(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSquare, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src * src); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kTan, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(tan(src)); }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNotEqualZero, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(src != static_cast<Src>(0.0));
}
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNanAssign, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return std::isnan(src) ? static_cast<Dst>(0.0) : src;
}
};
} // namespace primitive
} // namespace ep
} // namespace oneflow
...
@@ -37,6 +37,71 @@ bool IsPackSizeSupported(const size_t pack_size, size_t num_dims, const int64_t*
&& (reinterpret_cast<std::uintptr_t>(ptr) % (pack_size * sizeof(T)) == 0);
}
inline void CheckInplace(size_t num_dims, const int64_t* src_dims_or_strides, const void* src,
const int64_t* dst_dims_or_strides, const void* dst) {
if (src == dst) {
for (int64_t i = 0; i < num_dims; ++i) {
CHECK_EQ(src_dims_or_strides[i], dst_dims_or_strides[i]);
}
}
}
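// The check above only fires when src and dst alias (an in-place op); e.g.
// CheckInplace(2, src_dims, src, dst_dims, dst) is a no-op for distinct buffers and
// CHECK-fails on any mismatched dim/stride entry when src == dst.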
template<size_t max_num_dims>
inline void SimplifyBroadcastDims(size_t num_src_dims, const int64_t* src_dims,
const int64_t* src_strides, size_t num_dst_dims,
const int64_t* dst_dims, const int64_t* dst_strides,
size_t* simplified_num_dims, int64_t* simplified_src_dims,
int64_t* simplified_src_strides, int64_t* simplified_dst_dims,
int64_t* simplified_dst_strides) {
*simplified_num_dims = 0;
std::pair<int64_t, size_t> sorted_dst_strides[max_num_dims];
int64_t new_dst_dims[max_num_dims];
int64_t new_src_dims[max_num_dims];
int64_t new_dst_strides[max_num_dims];
int64_t new_src_strides[max_num_dims];
for (size_t i = 0; i < num_dst_dims; i++) { sorted_dst_strides[i] = {dst_strides[i], i}; }
std::sort(sorted_dst_strides, sorted_dst_strides + num_dst_dims,
[](auto pair1, auto pair2) { return pair1.first > pair2.first; });
const int64_t num_src_padding_dims = num_dst_dims - num_src_dims;
// dimension completion
int64_t expanded_src_dims[max_num_dims];
int64_t expanded_src_strides[max_num_dims];
for (int64_t i = num_dst_dims - 1; i >= 0; i--) {
expanded_src_dims[i] = i < num_src_padding_dims ? 1 : src_dims[i - num_src_padding_dims];
expanded_src_strides[i] = i < num_src_padding_dims ? 0 : src_strides[i - num_src_padding_dims];
}
// dimension permutation
for (int64_t i = num_dst_dims - 1; i >= 0; i--) {
size_t idx = sorted_dst_strides[i].second;
new_dst_dims[i] = dst_dims[idx];
new_dst_strides[i] = dst_strides[idx];
new_src_dims[i] = expanded_src_dims[idx];
new_src_strides[i] = expanded_src_strides[idx];
}
// dimension merge
bool prev_broadcast_src = false;
for (int64_t i = 0; i < num_dst_dims; ++i) {
const bool broadcast_src = (new_src_dims[i] == 1);
if (new_dst_dims[i] == 1) {
continue;
} else if (*simplified_num_dims != 0 && prev_broadcast_src == broadcast_src
&& (new_src_strides[i - 1] == new_src_strides[i] * new_src_dims[i])
&& (new_dst_strides[i - 1] == new_dst_strides[i] * new_dst_dims[i])) {
simplified_src_dims[*simplified_num_dims - 1] *= new_src_dims[i];
simplified_dst_dims[*simplified_num_dims - 1] *= new_dst_dims[i];
simplified_src_strides[*simplified_num_dims - 1] = new_src_strides[i];
simplified_dst_strides[*simplified_num_dims - 1] = new_dst_strides[i];
} else {
simplified_src_dims[*simplified_num_dims] = new_src_dims[i];
simplified_dst_dims[*simplified_num_dims] = new_dst_dims[i];
simplified_src_strides[*simplified_num_dims] = new_src_strides[i];
simplified_dst_strides[*simplified_num_dims] = new_dst_strides[i];
*simplified_num_dims += 1;
prev_broadcast_src = broadcast_src;
}
}
}
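// Worked example of the simplification above: dst_dims {2, 3, 4} with contiguous
// dst_strides {12, 4, 1} and src_dims {3, 4} with src_strides {4, 1}. The src shape is
// first padded to {1, 3, 4} with strides {0, 4, 1}; the two trailing dimensions are both
// non-broadcast and contiguous, so they merge, giving simplified_num_dims = 2 with
// src_dims {1, 12}, src_strides {0, 1}, dst_dims {2, 12}, dst_strides {12, 1}.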
inline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims,
const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims,
size_t* simplified_num_dims, int64_t* simplified_broadcast_dims,
...
@@ -42,15 +42,13 @@ Maybe<void> CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_
this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(), // NOLINT
options.GetPinnedDeviceIndex()); // NOLINT
CHECK_OR_RETURN(device);
JUST(device->AllocPinned(options, ptr, size));
} else {
*ptr = aligned_alloc(kMaxAlignmentRequirement, RoundUp(size, kMaxAlignmentRequirement));
if (*ptr == nullptr) { return Error::RuntimeError() << "allocate failed"; }
}
memset(*ptr, 0, size);
return Maybe<void>::Ok();
}
void CpuDevice::Free(const AllocationOptions& options, void* ptr) {
...
@@ -29,23 +29,224 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, Src, Dst> {
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(std::pow(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::fmod(src0, src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::fmod(src0, src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(std::fmod(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, bfloat16, bfloat16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
return std::fmod(src0, src1);
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::floor(src0 / src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src0, double src1) const {
return std::floor(src0 / src1);
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(std::floor(static_cast<float>(src0) / static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, bfloat16, bfloat16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
return std::floor(src0 / src1);
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::trunc(src0 / src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src0, double src1) const {
return std::trunc(src0 / src1);
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(std::trunc(static_cast<float>(src0) / static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, bfloat16, bfloat16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
return std::trunc(src0 / src1);
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src0, float src1) const {
float trunc_mod = std::fmod(src0, src1);
return (trunc_mod != static_cast<float>(0))
&& ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))
? trunc_mod + src1
: trunc_mod;
}
};
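// Floor-mod differs from std::fmod when the operands have opposite signs: e.g.
// fmod(-5.0f, 3.0f) == -2.0f, but the branch above returns -2.0f + 3.0f == 1.0f,
// matching a Python-style floored modulo; fmod(5.0f, -3.0f) == 2.0f likewise becomes -1.0f.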
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src0, double src1) const {
double trunc_mod = std::fmod(src0, src1);
return (trunc_mod != static_cast<double>(0))
&& ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))
? trunc_mod + src1
: trunc_mod;
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, bfloat16, bfloat16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
return static_cast<bfloat16>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarBasePowerGrad, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<float>()) {}
OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(
scalar_operand
* (std::pow(static_cast<float>(src0), scalar_operand - static_cast<float>(1)))
* static_cast<float>(src1));
}
float scalar_operand;
};
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {
return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
} }
}; };
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int8_t, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const {
return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, uint8_t, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const {
return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int64_t, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
OF_DEVICE_FUNC Dst operator()(int64_t src0, int64_t src1) const {
return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float16, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<float>()) {}
OF_DEVICE_FUNC Dst operator()(float16 src0, float16 src1) const {
return static_cast<Dst>(std::pow(scalar_operand, static_cast<float>(src0))
* std::log(scalar_operand) * static_cast<float>(src1));
}
float scalar_operand;
};
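// kScalarExpPowerGrad computes d(c^x)/dx * dy = c^x * ln(c) * dy for the fixed base
// c = scalar_operand (src0 is x, src1 is dy); the integer-typed specializations above
// simply promote their operands to float and reuse the float functor.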
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -59,6 +260,39 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kGeluBackwardWithDyX, Src, Dst>
Src coef = std::sqrt(2.0 / std::acos(-1.0));
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFastGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
// ref to: https://mlfromscratch.com/activation-functions-explained/#gelu
const Src one = static_cast<Src>(1);
const Src half = static_cast<Src>(0.5);
const Src pow3 = x * x * x;
const Src tanh_out = std::tanh(alpha * (x + beta * pow3));
const Src dtanh = alpha * (half * x + beta * static_cast<Src>(1.5) * pow3);
return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out));
}
private:
static constexpr Src alpha = static_cast<Src>(0.7978845608028654);
static constexpr Src beta = static_cast<Src>(0.044714998453855515);
};
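// This is the derivative of the tanh-based fast-GELU approximation
// y = 0.5 * x * (1 + tanh(alpha * (x + beta * x^3))) with alpha = sqrt(2/pi) and
// beta ~= 0.044715: d y / d x = 0.5 * (1 + tanh_out) + dtanh * (1 - tanh_out * tanh_out),
// which the expression above multiplies by the incoming gradient dy.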
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kQuickGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
const Src one = static_cast<Src>(1.0);
const Src sigmoid = one / (one + exp(-x * alpha));
return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid)));
}
private:
static constexpr Src alpha = static_cast<Src>(1.702);
};
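// Derivative of quick-GELU, y = x * sigmoid(alpha * x) with alpha = 1.702:
// d y / d x = sigmoid + alpha * x * sigmoid * (1 - sigmoid), multiplied by the incoming
// gradient dy above.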
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -69,6 +303,82 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyX, Src, Dst>
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcosBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy * -(static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) - x * x));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcoshBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy / sqrt(x * x - static_cast<Src>(1.0));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) - x * x));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinhBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) + x * x));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy * static_cast<Src>(2.0) * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI)))
* exp(-x * x);
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfcBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return dy * static_cast<Src>(-2.0) * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI)))
* exp(-x * x);
}
};
#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type) \
template<> \
struct BinaryFunctor<DeviceType::kCPU, op, type, type> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCPU, op, int, int> int_functor; \
OF_DEVICE_FUNC type operator()(type src0, type src1) const { \
return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \
} \
};
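// bool and char have no meaningful pow/fmod/div semantics of their own, so these
// specializations promote both operands to int, reuse the int functor, and cast the result back.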
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
......
...@@ -45,6 +45,11 @@ float16 GetValue<float16>(Scalar value) { ...@@ -45,6 +45,11 @@ float16 GetValue<float16>(Scalar value) {
return static_cast<float16>(GetValue<float>(value)); return static_cast<float16>(GetValue<float>(value));
} }
template<>
bfloat16 GetValue<bfloat16>(Scalar value) {
return static_cast<bfloat16>(GetValue<float>(value));
}
template<BinaryOp binary_op, typename Src, typename Dst> template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarFunctor { struct BinaryLhsScalarFunctor {
BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1) BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
...@@ -247,8 +252,8 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di ...@@ -247,8 +252,8 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di
SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims, SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
&simplified_num_dims, simplified_src0_dims, &simplified_num_dims, simplified_src0_dims,
simplified_src1_dims, simplified_dst_dims); simplified_src1_dims, simplified_dst_dims);
CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst);
simplified_dst_dims, dst); CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst);
if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims, if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
simplified_src1_dims)) { simplified_src1_dims)) {
LaunchElementwise<binary_op, Src, Dst>(cpu_stream, simplified_num_dims, simplified_src0_dims, LaunchElementwise<binary_op, Src, Dst>(cpu_stream, simplified_num_dims, simplified_src0_dims,
...@@ -260,16 +265,20 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di ...@@ -260,16 +265,20 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di
} else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) { } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {
LaunchBinaryRhsScalar<binary_op, Src, Dst>(cpu_stream, *src1, simplified_src0_dims[0], src0, LaunchBinaryRhsScalar<binary_op, Src, Dst>(cpu_stream, *src1, simplified_src0_dims[0], src0,
dst, attr0, attr1); dst, attr0, attr1);
} else if (simplified_num_dims == 2 && simplified_src0_dims[0] == 1) { } else if (simplified_num_dims == 2 && simplified_src0_dims[0] == 1
&& simplified_src0_dims[1] == simplified_src1_dims[1]) {
LaunchRowWithMatrix<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0, LaunchRowWithMatrix<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,
simplified_src1_dims, src1, dst, attr0, attr1); simplified_src1_dims, src1, dst, attr0, attr1);
} else if (simplified_num_dims == 2 && simplified_src1_dims[0] == 1) { } else if (simplified_num_dims == 2 && simplified_src1_dims[0] == 1
&& simplified_src0_dims[1] == simplified_src1_dims[1]) {
LaunchMatrixWithRow<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0, LaunchMatrixWithRow<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,
simplified_src1_dims, src1, dst, attr0, attr1); simplified_src1_dims, src1, dst, attr0, attr1);
} else if (simplified_num_dims == 2 && simplified_src0_dims[1] == 1) { } else if (simplified_num_dims == 2 && simplified_src0_dims[1] == 1
&& simplified_src0_dims[0] == simplified_src1_dims[0]) {
LaunchColWithMatrix<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0, LaunchColWithMatrix<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,
simplified_src1_dims, src1, dst, attr0, attr1); simplified_src1_dims, src1, dst, attr0, attr1);
} else if (simplified_num_dims == 2 && simplified_src1_dims[1] == 1) { } else if (simplified_num_dims == 2 && simplified_src1_dims[1] == 1
&& simplified_src0_dims[0] == simplified_src1_dims[0]) {
LaunchMatrixWithCol<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0, LaunchMatrixWithCol<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,
simplified_src1_dims, src1, dst, attr0, attr1); simplified_src1_dims, src1, dst, attr0, attr1);
} else { } else {
...@@ -405,8 +414,8 @@ class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { ...@@ -405,8 +414,8 @@ class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
src1_dims, dst_dims); src1_dims, dst_dims);
} }
CheckInplace(num_dims, src_0_dims.data(), onednn_src0, src_1_dims.data(), onednn_src1, CheckInplace(num_dims, src_0_dims.data(), onednn_src0, dst_dims.data(), dst);
dst_dims.data(), dst); CheckInplace(num_dims, src_1_dims.data(), onednn_src1, dst_dims.data(), dst);
auto src_0_md = dnnl::memory::desc( auto src_0_md = dnnl::memory::desc(
src_0_dims, src_onednn, src_0_dims, src_onednn,
...@@ -564,7 +573,11 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF ...@@ -564,7 +573,11 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY, MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)}; BINARY_ACTIVATION_BACKWARD_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_BACKWARD_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY #undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/cpu/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/ep/cpu/cpu_stream.h"
#include "oneflow/core/ep/cpu/cpu_device.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_unary {
namespace {
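// A layout is contiguous when the innermost stride is 1 and every other stride equals
// dims[i + 1] * strides[i + 1].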
bool IsContiguous(size_t num_dims, const int64_t* dims, const int64_t* strides) {
for (int i = num_dims - 1; i >= 0; i--) {
if ((i == num_dims - 1 && strides[i] != 1)
|| (i != num_dims - 1 && strides[i] != dims[i + 1] * strides[i + 1])) {
return false;
}
}
return true;
}
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchScalarFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t stride,
Scalar attr0, Scalar attr1) {
auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst>(attr0, attr1);
Dst scalar_value = functor(*src);
stream->ParallelFor(0, count, [dst, stride, scalar_value](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) { dst[i * stride] = scalar_value; }
});
}
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchTensorFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t dst_stride,
size_t src_stride, Scalar attr0, Scalar attr1) {
auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst>(attr0, attr1);
stream->ParallelFor(0, count,
[functor, src, dst, src_stride, dst_stride](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
dst[i * dst_stride] = functor(src[i * src_stride]);
}
});
}
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchGeneral(CpuStream* stream, Dst* dst, const Src* src, size_t num_dims,
const int64_t* dst_dims, const int64_t* src_dims, const int64_t* dst_stride,
const int64_t* src_stride, Scalar attr0, Scalar attr1) {
bool contiguous_output = IsContiguous(num_dims, dst_dims, dst_stride);
const int64_t elem_cnt = GetElementCount(num_dims, dst_dims);
auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst>(attr0, attr1);
stream->ParallelFor(
0, elem_cnt,
[functor, src, dst, num_dims, src_dims, dst_dims, src_stride, dst_stride, contiguous_output](
int64_t begin, int64_t end) {
auto src_index_to_offset_helper =
IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(src_stride, num_dims);
auto dst_offset_to_index_helper =
OffsetToIndexWithStrideCalculator<int64_t, kMaxNumDims>(dst_dims, num_dims);
auto dst_index_to_offset_helper =
IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(dst_stride, num_dims);
int64_t src_index[kMaxNumDims];
int64_t dst_index[kMaxNumDims];
for (int64_t offset = begin; offset < end; offset++) {
dst_offset_to_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
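// Broadcasting: a source dimension of size 1 always maps to index 0.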
for (int i = 0; i < kMaxNumDims; i++) {
if (i < num_dims) {
src_index[i] = (src_dims[i] != 1) ? dst_index[i] : 0;
} else {
src_index[i] = 0;
}
}
const int64_t src_offset =
src_index_to_offset_helper.NdIndexToOffset(src_index, num_dims);
if (!contiguous_output) {
const int64_t dst_offset =
dst_index_to_offset_helper.NdIndexToOffset(dst_index, num_dims);
dst[dst_offset] = functor(src[src_offset]);
} else {
dst[offset] = functor(src[src_offset]);
}
}
});
}
template<UnaryOp unary_op, typename Src, typename Dst>
class BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl);
BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~BroadcastElementwiseUnaryImpl() override = default;
void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,
size_t num_dst_dims, const int64_t* dst_dims, void* dst) override {
CHECK_GT(num_src_dims, 0) << "num_src_dims must be greater than 0";
CHECK_GT(num_dst_dims, 0) << "num_dst_dims must be greater than 0";
int64_t src_strides[kMaxNumDims];
int64_t dst_strides[kMaxNumDims];
// initialize contiguous (row-major) strides; trailing entries up to kMaxNumDims default to 1
for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; }
for (int i = num_src_dims - 2; i >= 0; --i) {
src_strides[i] = src_dims[i + 1] * src_strides[i + 1];
}
for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; }
for (int i = num_dst_dims - 2; i >= 0; --i) {
dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1];
}
Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides,
dst);
}
void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,
const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims,
const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override {
CHECK_GT(num_src_dims, 0) << "num_src_dims must be greater than 0";
CHECK_GT(num_dst_dims, 0) << "num_dst_dims must be greater than 0";
auto* cpu_stream = stream->As<CpuStream>();
Dst* dst = reinterpret_cast<Dst*>(dst_ptr);
const Src* src = reinterpret_cast<const Src*>(src_ptr);
size_t simplified_num_dims = 0;
int64_t simplified_src_dims[kMaxNumDims];
int64_t simplified_dst_dims[kMaxNumDims];
int64_t simplified_src_strides[kMaxNumDims];
int64_t simplified_dst_strides[kMaxNumDims];
SimplifyBroadcastDims<kMaxNumDims>(num_src_dims, src_dims, src_strides, num_dst_dims, dst_dims,
dst_strides, &simplified_num_dims, simplified_src_dims,
simplified_src_strides, simplified_dst_dims,
simplified_dst_strides);
CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst);
CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst);
if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) {
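// The source has a single element: compute the result once and fill the whole output.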
const int64_t elem_cnt = simplified_dst_dims[0];
const int64_t dst_stride = simplified_dst_strides[0];
LaunchScalarFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride, attr0,
attr1);
} else if (simplified_num_dims == 1) {
const int64_t elem_cnt = simplified_src_dims[0];
const int64_t src_stride = simplified_src_strides[0];
const int64_t dst_stride = simplified_dst_strides[0];
LaunchTensorFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride, src_stride,
attr0, attr1);
} else {
LaunchGeneral<unary_op, Src, Dst>(
cpu_stream, dst, src, simplified_num_dims, simplified_dst_dims, simplified_src_dims,
simplified_dst_strides, simplified_src_strides, attr0, attr1);
}
}
protected:
Scalar attr0, attr1;
};
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseUnary> NewBroadcastElementwiseUnary(Scalar attr0,
Scalar attr1) {
return std::unique_ptr<BroadcastElementwiseUnary>(
new BroadcastElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}
class BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl);
BroadcastElementwiseUnaryFactoryImpl() = default;
~BroadcastElementwiseUnaryFactoryImpl() override = default;
std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims) override {
return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
}
std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims, Scalar attr0) override {
return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
}
std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
DataType dst_type, size_t max_num_dims,
Scalar attr0, Scalar attr1) override {
if (max_num_dims > kMaxNumDims) { return nullptr; }
#define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
NewBroadcastElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), \
OF_PP_PAIR_FIRST(dtype_pair)>},
static const std::map<std::tuple<UnaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseUnary>(Scalar, Scalar)>>
new_broadcast_elementwise_unary_handle{
// For All Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY,
UNARY_BROADCAST_OP_SEQ, CPU_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY
const auto iter =
new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type));
if (iter != new_broadcast_elementwise_unary_handle.end()) {
return iter->second(attr0, attr1);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseUnaryFactory,
BroadcastElementwiseUnaryFactoryImpl);
} // namespace
} // namespace broadcast_elementwise_unary
} // namespace primitive
} // namespace ep
} // namespace oneflow
...@@ -23,10 +23,28 @@ namespace primitive { ...@@ -23,10 +23,28 @@ namespace primitive {
namespace { namespace {
template<typename From, typename To> template<typename From, typename To, typename = void>
void CastCpu(const From* from, To* to, size_t count) { struct CpuCastFunctor {
for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(from[i]); } static void Call(const From* from, To* to, size_t count) {
} for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(from[i]); }
}
};
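// bfloat16 has no direct conversion to or from arbitrary types, so the two specializations
// below route the cast through float in both directions.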
template<typename To>
struct CpuCastFunctor<bfloat16, To,
typename std::enable_if<!(std::is_same<To, bfloat16>::value)>::type> {
static void Call(const bfloat16* from, To* to, size_t count) {
for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(static_cast<float>(from[i])); }
}
};
template<typename From>
struct CpuCastFunctor<From, bfloat16,
typename std::enable_if<!(std::is_same<From, bfloat16>::value)>::type> {
static void Call(const From* from, bfloat16* to, size_t count) {
for (size_t i = 0; i < count; ++i) { to[i] = bfloat16(static_cast<float>(from[i])); }
}
};
template<typename From, typename To> template<typename From, typename To>
class CastImpl : public Cast { class CastImpl : public Cast {
...@@ -36,7 +54,8 @@ class CastImpl : public Cast { ...@@ -36,7 +54,8 @@ class CastImpl : public Cast {
~CastImpl() override = default; ~CastImpl() override = default;
void Launch(Stream* stream, const void* from, void* to, size_t count) override { void Launch(Stream* stream, const void* from, void* to, size_t count) override {
CastCpu(reinterpret_cast<const From*>(from), reinterpret_cast<To*>(to), count); CpuCastFunctor<From, To>::Call(reinterpret_cast<const From*>(from), reinterpret_cast<To*>(to),
count);
} }
}; };
...@@ -56,7 +75,8 @@ std::unique_ptr<Cast> NewCast() { ...@@ -56,7 +75,8 @@ std::unique_ptr<Cast> NewCast() {
CPU_PRIMITIVE_UINT64_TYPE_SEQ \ CPU_PRIMITIVE_UINT64_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \
CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ
class CastFactoryImpl : public CastFactory { class CastFactoryImpl : public CastFactory {
public: public:
......
...@@ -56,6 +56,11 @@ float16 GetValue<float16>(Scalar value) { ...@@ -56,6 +56,11 @@ float16 GetValue<float16>(Scalar value) {
return static_cast<float16>(GetValue<float>(value)); return static_cast<float16>(GetValue<float>(value));
} }
template<>
bfloat16 GetValue<bfloat16>(Scalar value) {
return static_cast<bfloat16>(GetValue<float>(value));
}
template<size_t num_dims, typename IndexType, typename StorageType> template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(ConstantPadParams<num_dims, IndexType> params, StorageType packed_pad_val) { void LaunchKernel(ConstantPadParams<num_dims, IndexType> params, StorageType packed_pad_val) {
ConstantPadKernel<num_dims, IndexType, StorageType>(params, packed_pad_val); ConstantPadKernel<num_dims, IndexType, StorageType>(params, packed_pad_val);
...@@ -163,6 +168,7 @@ template<typename T> ...@@ -163,6 +168,7 @@ template<typename T>
void SimplifyThenLaunch(size_t num_dims, const int64_t* src_dims, const void* src, void SimplifyThenLaunch(size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, T pad_val, const int64_t* padding_before, const int64_t* padding_after, T pad_val,
void* dst) { void* dst) {
CHECK_GT(num_dims, 0) << "num_dims must be greater than 0";
CHECK_LE(num_dims, kMaxNumDims); CHECK_LE(num_dims, kMaxNumDims);
int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims];
int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims];
......
...@@ -88,9 +88,13 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory { ...@@ -88,9 +88,13 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_MATH_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ) UNARY_MATH_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ)
// For Float Type OP // For Float Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ)
// For Int Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_FLOATING_MATH_OP_SEQ, UNARY_INT_MATH_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ)
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)
// For Utils OP // For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
......
...@@ -34,6 +34,11 @@ float16 GetValue<float16>(Scalar value) { ...@@ -34,6 +34,11 @@ float16 GetValue<float16>(Scalar value) {
return static_cast<float16>(GetValue<float>(value)); return static_cast<float16>(GetValue<float>(value));
} }
template<>
bfloat16 GetValue<bfloat16>(Scalar value) {
return static_cast<bfloat16>(GetValue<float>(value));
}
template<typename T> template<typename T>
class FillImpl : public Fill { class FillImpl : public Fill {
public: public:
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/tensor_fill.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<typename T>
class TensorFillImpl : public TensorFill {
public:
OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl);
TensorFillImpl() = default;
~TensorFillImpl() override = default;
void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
const T* value = reinterpret_cast<const T*>(src);
std::fill_n(reinterpret_cast<T*>(dst), count, value[0]);
}
};
template<typename T>
std::unique_ptr<TensorFill> NewTensorFill() {
return std::unique_ptr<TensorFill>(new TensorFillImpl<T>());
}
class TensorFillFactoryImpl : public TensorFillFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl);
TensorFillFactoryImpl() = default;
~TensorFillFactoryImpl() override = default;
std::unique_ptr<TensorFill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<TensorFill>()>> new_fill_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
const auto it = new_fill_handle.find(data_type);
if (it != new_fill_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, TensorFillFactory, TensorFillFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
...@@ -35,6 +35,7 @@ limitations under the License. ...@@ -35,6 +35,7 @@ limitations under the License.
#define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) #define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#define CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)
#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \ #define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool) OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool)
...@@ -63,12 +64,19 @@ limitations under the License. ...@@ -63,12 +64,19 @@ limitations under the License.
#define CPU_PRIMITIVE_ALL_TYPE_SEQ \ #define CPU_PRIMITIVE_ALL_TYPE_SEQ \
CPU_PRIMITIVE_NATIVE_TYPE_SEQ \ CPU_PRIMITIVE_NATIVE_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \
CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define CPU_PRIMITIVE_FLOATING_TYPE_SEQ \ #define CPU_PRIMITIVE_FLOATING_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ CPU_PRIMITIVE_DOUBLE_TYPE_SEQ
#define CPU_PRIMITIVE_INT_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ
#define UTIL_OPS_DATA_TYPE_SEQ \ #define UTIL_OPS_DATA_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \ CPU_PRIMITIVE_UINT8_TYPE_SEQ \
......
...@@ -15,7 +15,6 @@ limitations under the License. ...@@ -15,7 +15,6 @@ limitations under the License.
*/ */
#include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include <cmath>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
...@@ -23,7 +22,7 @@ namespace primitive { ...@@ -23,7 +22,7 @@ namespace primitive {
template<typename Dst, typename Src> template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> { struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Src>(0.5) * src * (static_cast<Src>(1.0) + std::erf(inv_sqrt2 * src)); return static_cast<Src>(0.5) * src * (static_cast<Src>(1.0) + std::erf(inv_sqrt2 * src));
...@@ -31,9 +30,42 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> { ...@@ -31,9 +30,42 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> {
Src inv_sqrt2 = std::sqrt(0.5); Src inv_sqrt2 = std::sqrt(0.5);
}; };
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kFastGelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
// ref to: https://mlfromscratch.com/activation-functions-explained/#gelu
const Src half = static_cast<Src>(0.5);
const Src one = static_cast<Src>(1);
const Src tanh_in = alpha * (src + beta * src * src * src);
return half * src * (one + std::tanh(tanh_in));
}
private:
// constant ref to:
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py
static constexpr Src alpha = static_cast<Src>(0.7978845608028654);
static constexpr Src beta = static_cast<Src>(0.044714998453855515);
};
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kQuickGelu, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
const Src sigmoid = static_cast<Src>(1.0) / (static_cast<Src>(1.0) + std::exp(-src * alpha));
return src * sigmoid;
}
private:
static constexpr Src alpha = static_cast<Src>(1.702);
};
template<typename Dst, typename Src> template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> { struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return std::tanh(src); } OF_DEVICE_FUNC Dst operator()(Src src) const { return std::tanh(src); }
}; };
...@@ -66,6 +98,109 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, double> { ...@@ -66,6 +98,109 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, double> {
OF_DEVICE_FUNC bool operator()(double src) const { return std::isnan(src); } OF_DEVICE_FUNC bool operator()(double src) const { return std::isnan(src); }
}; };
template<typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsFinite, bool, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(Src src) const { return std::isfinite(src); }
};
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrunc, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(std::trunc(src)); }
};
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kRsqrt, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>(static_cast<Src>(1.0) / static_cast<Src>(std::sqrt(src)));
}
};
template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kAbs, bfloat16, bfloat16> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { return std::abs(src); }
};
#define SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(op) \
template<> \
struct UnaryFunctor<DeviceType::kCPU, op, bfloat16, bfloat16> { \
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
UnaryFunctor<DeviceType::kCPU, op, float, float> float_functor; \
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { \
return bfloat16(float_functor(static_cast<float>(src))); \
} \
};
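// The following ops have no native bfloat16 implementation on CPU: compute in float and
// narrow the result back to bfloat16.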
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcos);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcosh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsin);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsinh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtanh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCeil);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCos);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCosh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErf);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErfc);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExpm1);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFloor);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLgamma);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog2);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog1p);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLogSigmoid);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRint);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRound);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRsqrt);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSigmoid);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSin);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSinh);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSqrt);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquare);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, bfloat16> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, bfloat16> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isnan(src); }
};
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
...@@ -55,6 +55,13 @@ CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager) ...@@ -55,6 +55,13 @@ CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
const_ones_buffer_bf16_(nullptr) { const_ones_buffer_bf16_(nullptr) {
CudaCurrentDeviceGuard guard(device_index_); CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(cudaGetDeviceProperties(&properties_, device_index_)); OF_CUDA_CHECK(cudaGetDeviceProperties(&properties_, device_index_));
{
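// Optionally override the CUDA device flags (e.g. the scheduling policy) through the
// ONEFLOW_EP_CUDA_DEVICE_FLAGS environment variable.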
const char* env_name = "ONEFLOW_EP_CUDA_DEVICE_FLAGS";
if (std::getenv(env_name) != nullptr) {
const unsigned int flags = ParseIntegerFromEnv(env_name, 0);
OF_CUDA_CHECK(cudaSetDeviceFlags(flags));
}
}
event_flags_ = cudaEventDisableTiming; event_flags_ = cudaEventDisableTiming;
if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) { if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
event_flags_ |= cudaEventBlockingSync; event_flags_ |= cudaEventBlockingSync;
...@@ -119,6 +126,10 @@ Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size ...@@ -119,6 +126,10 @@ Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size
CHECK(!options.HasPinnedDevice()); CHECK(!options.HasPinnedDevice());
cudaError_t err = cudaMalloc(ptr, size); cudaError_t err = cudaMalloc(ptr, size);
if (err != cudaSuccess) { if (err != cudaSuccess) {
if (err == cudaErrorMemoryAllocation) {
// NOTE: return an out-of-memory error so the VM can try to shrink memory and rerun
return Error::OutOfMemoryError() << cudaGetErrorString(err);
}
return Error::RuntimeError() << cudaGetErrorString(err); return Error::RuntimeError() << cudaGetErrorString(err);
} else { } else {
return Maybe<void>::Ok(); return Maybe<void>::Ok();
...@@ -177,3 +188,176 @@ const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const { ...@@ -177,3 +188,176 @@ const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
} // namespace oneflow } // namespace oneflow
#endif // WITH_CUDA #endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultConstBufElementCount = 1024 * 1024;
template<typename T>
void CreateConstBuffer(void** buf, T value, size_t n) {
OF_CUDA_CHECK(hipMalloc(buf, n * sizeof(T)));
std::vector<T> host(n, value);
OF_CUDA_CHECK(hipMemcpy(*buf, host.data(), n * sizeof(T), hipMemcpyDefault));
}
} // namespace
CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
: device_index_(device_index),
event_flags_{},
properties_{},
device_manager_(device_manager),
const_buf_elem_cnt_(0),
const_zeros_buffer_(nullptr),
const_ones_buffer_fp32_(nullptr),
const_ones_buffer_fp16_(nullptr),
const_ones_buffer_bf16_(nullptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipGetDeviceProperties(&properties_, device_index_));
{
const char* env_name = "ONEFLOW_EP_CUDA_DEVICE_FLAGS";
if (std::getenv(env_name) != nullptr) {
const unsigned int flags = ParseIntegerFromEnv(env_name, 0);
OF_CUDA_CHECK(hipSetDeviceFlags(flags));
}
}
event_flags_ = hipEventDisableTiming;
if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
event_flags_ |= hipEventBlockingSync;
}
const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT",
kDefaultConstBufElementCount);
if (const_buf_elem_cnt_ > 0) {
CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);
CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),
const_buf_elem_cnt_);
CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0), const_buf_elem_cnt_);
// #if CUDA_VERSION >= 11000
// CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),
// const_buf_elem_cnt_);
// #endif
}
}
CudaDevice::~CudaDevice() {
CudaCurrentDeviceGuard guard(device_index_);
for (auto* event : events_) { delete event; }
OF_CUDA_CHECK(hipFree(const_zeros_buffer_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp32_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp16_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_bf16_));
}
void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(hipSetDevice(device_index_)); }
Stream* CudaDevice::CreateStream() {
CudaCurrentDeviceGuard guard(device_index_);
return new CudaStream(this);
}
void CudaDevice::DestroyStream(Stream* stream) {
CudaCurrentDeviceGuard guard(device_index_);
delete stream;
}
void CudaDevice::CreateEvents(Event** events, size_t count) {
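// Reuse events from the pool first; only create new ones for whatever is still missing.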
size_t copied = 0;
{
std::lock_guard<std::mutex> lock(events_mutex_);
copied = std::min(count, events_.size());
size_t offset = events_.size() - copied;
std::copy(events_.begin() + offset, events_.end(), events);
events_.resize(offset);
}
if (copied != count) {
CudaCurrentDeviceGuard guard(device_index_);
for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }
}
}
void CudaDevice::DestroyEvents(Event** events, size_t count) {
std::lock_guard<std::mutex> lock(events_mutex_);
events_.insert(events_.end(), events, events + count);
}
Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
CHECK(!options.HasPinnedDevice());
hipError_t err = hipMalloc(ptr, size);
if (err != hipSuccess) {
if (err == hipErrorMemoryAllocation) {
// NOTE: return an out-of-memory error so the VM can try to shrink memory and rerun
return Error::OutOfMemoryError() << hipGetErrorString(err);
}
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::Free(const AllocationOptions& attr, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(ptr));
}
Maybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
hipError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);
if (err != hipSuccess) {
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipHostFree(ptr));
}
const hipDeviceProp_t& CudaDevice::properties() const { return properties_; }
const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {
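// The zeros buffer holds const_buf_elem_cnt_ floats; any dtype may reuse it as long as the
// requested byte size fits, otherwise fall back to nullptr.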
if (GetSizeOfDataType(data_type) * n
<= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {
return const_zeros_buffer_;
} else {
return nullptr;
}
}
const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
if (n <= const_buf_elem_cnt_) {
if (data_type == DataType::kFloat) {
return const_ones_buffer_fp32_;
} else if (data_type == DataType::kFloat16) {
return const_ones_buffer_fp16_;
} else if (data_type == DataType::kBFloat16) {
return const_ones_buffer_bf16_;
} else {
return nullptr;
}
} else {
return nullptr;
}
}
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
...@@ -75,4 +75,60 @@ class CudaDevice : public Device { ...@@ -75,4 +75,60 @@ class CudaDevice : public Device {
#endif // WITH_CUDA #endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
class CudaDevice : public Device {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDevice);
explicit CudaDevice(int device_index, DeviceManager* device_manager);
~CudaDevice() override;
void SetAsActiveDevice() override;
DeviceType device_type() const override { return DeviceType::kCUDA; }
size_t device_index() const override { return device_index_; }
DeviceManager* device_manager() const override { return device_manager_; }
Stream* CreateStream() override;
void DestroyStream(Stream* stream) override;
void CreateEvents(Event** events, size_t count) override;
void DestroyEvents(Event** events, size_t count) override;
Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;
void Free(const AllocationOptions& options, void* ptr) override;
Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;
void FreePinned(const AllocationOptions& options, void* ptr) override;
const hipDeviceProp_t& properties() const;
const void* GetConstZeros(DataType data_type, size_t n) const;
const void* GetConstOnes(DataType data_type, size_t n) const;
private:
int device_index_;
std::mutex events_mutex_;
std::vector<Event*> events_;
unsigned int event_flags_;
hipDeviceProp_t properties_;
DeviceManager* device_manager_;
int64_t const_buf_elem_cnt_;
void* const_zeros_buffer_;
void* const_ones_buffer_fp32_;
void* const_ones_buffer_fp16_;
void* const_ones_buffer_bf16_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_ #endif // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_