Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1766 additions
and
95 deletions
+1766
-95
oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h
...w/core/ep/common/primitive/broadcast_elementwise_binary.h
+63
-26
oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h
...ow/core/ep/common/primitive/broadcast_elementwise_unary.h
+182
-0
oneflow/core/ep/common/primitive/copy_nd.h
oneflow/core/ep/common/primitive/copy_nd.h
+1
-0
oneflow/core/ep/common/primitive/elementwise_unary.h
oneflow/core/ep/common/primitive/elementwise_unary.h
+61
-18
oneflow/core/ep/common/primitive/matmul.cpp
oneflow/core/ep/common/primitive/matmul.cpp
+0
-2
oneflow/core/ep/common/primitive/unary_functor.h
oneflow/core/ep/common/primitive/unary_functor.h
+317
-16
oneflow/core/ep/common/primitive/util.h
oneflow/core/ep/common/primitive/util.h
+65
-0
oneflow/core/ep/cpu/cpu_device.cpp
oneflow/core/ep/cpu/cpu_device.cpp
+5
-7
oneflow/core/ep/cpu/primitive/binary_functor.h
oneflow/core/ep/cpu/primitive/binary_functor.h
+315
-5
oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
...ow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
+22
-9
oneflow/core/ep/cpu/primitive/broadcast_elementwise_unary.cpp
...low/core/ep/cpu/primitive/broadcast_elementwise_unary.cpp
+233
-0
oneflow/core/ep/cpu/primitive/cast.cpp
oneflow/core/ep/cpu/primitive/cast.cpp
+26
-6
oneflow/core/ep/cpu/primitive/constant_pad.cpp
oneflow/core/ep/cpu/primitive/constant_pad.cpp
+6
-0
oneflow/core/ep/cpu/primitive/elementwise_unary.cpp
oneflow/core/ep/cpu/primitive/elementwise_unary.cpp
+6
-2
oneflow/core/ep/cpu/primitive/fill.cpp
oneflow/core/ep/cpu/primitive/fill.cpp
+5
-0
oneflow/core/ep/cpu/primitive/tensor_fill.cpp
oneflow/core/ep/cpu/primitive/tensor_fill.cpp
+72
-0
oneflow/core/ep/cpu/primitive/type_seq.h
oneflow/core/ep/cpu/primitive/type_seq.h
+9
-1
oneflow/core/ep/cpu/primitive/unary_functor.h
oneflow/core/ep/cpu/primitive/unary_functor.h
+138
-3
oneflow/core/ep/cuda/cuda_device.cpp
oneflow/core/ep/cuda/cuda_device.cpp
+184
-0
oneflow/core/ep/cuda/cuda_device.h
oneflow/core/ep/cuda/cuda_device.h
+56
-0
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h
View file @
a715222c
...
...
@@ -30,15 +30,6 @@ namespace broadcast_elementwise_binary {
constexpr
size_t
kMaxNumDims
=
8
;
inline
void
CheckInplace
(
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
const
void
*
dst
)
{
for
(
int64_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
if
(
src0
==
dst
)
{
CHECK_EQ
(
src0_dims
[
i
],
dst_dims
[
i
]);
}
if
(
src1
==
dst
)
{
CHECK_EQ
(
src1_dims
[
i
],
dst_dims
[
i
]);
}
}
}
inline
bool
IsDimsEquals
(
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
)
{
if
(
num_src0_dims
!=
num_src1_dims
)
{
return
false
;
}
...
...
@@ -48,22 +39,36 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
return
true
;
}
#define BINARY_MATH_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow)
#define BINARY_COMPARISION_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual)
#define BINARY_MATH_OP_SEQ_0 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFmod) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTruncDiv) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorMod) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarExpPowerGrad)
#define BINARY_MATH_OP_SEQ_1 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarBasePowerGrad)
#define BINARY_MATH_OP_SEQ \
BINARY_MATH_OP_SEQ_0 \
BINARY_MATH_OP_SEQ_1
#define BINARY_COMPARISION_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsCloseEqualNan) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsClose)
#define BINARY_LOGICAL_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd) \
...
...
@@ -87,7 +92,39 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftplusBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftshrinkBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX)
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX)
#define BINARY_MATH_BACKWARD_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAbsBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcosBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCosBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog2BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog10BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog1pBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogSigmoidBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalNoNanBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kRsqrtBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanBackwardWithDyX)
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
...
...
oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
#define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/include/primitive/fast_integer_math.h"
#include "oneflow/core/ep/common/primitive/util.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_unary
{
constexpr
size_t
kMaxNumDims
=
8
;
template
<
typename
T
,
int
N
>
class
IndexToOffsetWithStrideCalculator
{
public:
IndexToOffsetWithStrideCalculator
()
{}
OF_DEVICE_FUNC
explicit
IndexToOffsetWithStrideCalculator
(
const
T
*
strides
)
{
InitStrides
(
strides
,
N
);
}
template
<
typename
U
>
OF_DEVICE_FUNC
explicit
IndexToOffsetWithStrideCalculator
(
const
U
*
strides
)
{
T
strides_arr
[
N
];
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
strides_arr
[
i
]
=
strides
[
i
];
}
InitStrides
(
strides_arr
,
N
);
}
OF_DEVICE_FUNC
explicit
IndexToOffsetWithStrideCalculator
(
const
T
*
strides
,
int
n
)
{
InitStrides
(
strides
,
n
);
}
template
<
typename
U
>
OF_DEVICE_FUNC
explicit
IndexToOffsetWithStrideCalculator
(
const
U
*
strides
,
int
n
)
{
T
strides_arr
[
N
];
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
if
(
i
<
n
)
{
strides_arr
[
i
]
=
strides
[
i
];
}
}
InitStrides
(
strides_arr
,
n
);
}
~
IndexToOffsetWithStrideCalculator
()
=
default
;
OF_DEVICE_FUNC
T
NdIndexToOffset
(
const
T
*
index
)
const
{
T
offset
=
0
;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for
(
int
i
=
0
;
i
<
N
-
1
;
++
i
)
{
offset
+=
index
[
i
]
*
stride_
[
i
];
}
offset
+=
index
[
N
-
1
];
return
offset
;
}
OF_DEVICE_FUNC
T
NdIndexToOffset
(
const
T
*
index
,
int
n
)
const
{
assert
(
n
<=
N
);
T
offset
=
0
;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
if
(
i
<
n
)
{
offset
+=
index
[
i
]
*
stride_
[
i
];
}
}
return
offset
;
}
OF_DEVICE_FUNC
constexpr
int
Size
()
const
{
return
N
;
}
private:
OF_DEVICE_FUNC
void
InitStrides
(
const
T
*
strides
,
const
int
n
)
{
for
(
int
i
=
n
;
i
<
N
;
++
i
)
{
stride_
[
i
]
=
1
;
}
for
(
int
i
=
n
-
1
;
i
>=
0
;
--
i
)
{
stride_
[
i
]
=
strides
[
i
];
}
}
T
stride_
[
N
];
};
template
<
typename
T
,
int
N
>
class
OffsetToIndexWithStrideCalculator
{
public:
OffsetToIndexWithStrideCalculator
()
{}
OF_DEVICE_FUNC
explicit
OffsetToIndexWithStrideCalculator
(
const
T
*
dims
)
{
InitFastIntegerMath
(
dims
,
N
);
}
template
<
typename
U
>
OF_DEVICE_FUNC
explicit
OffsetToIndexWithStrideCalculator
(
const
U
*
dims
)
{
T
dims_arr
[
N
];
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dims_arr
[
i
]
=
dims
[
i
];
}
InitFastIntegerMath
(
dims_arr
,
N
);
}
OF_DEVICE_FUNC
explicit
OffsetToIndexWithStrideCalculator
(
const
T
*
dims
,
int
n
)
{
InitFastIntegerMath
(
dims
,
n
);
}
template
<
typename
U
>
OF_DEVICE_FUNC
explicit
OffsetToIndexWithStrideCalculator
(
const
U
*
dims
,
int
n
)
{
T
dims_arr
[
N
];
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
if
(
i
<
n
)
{
dims_arr
[
i
]
=
dims
[
i
];
}
}
InitFastIntegerMath
(
dims_arr
,
n
);
}
~
OffsetToIndexWithStrideCalculator
()
=
default
;
OF_DEVICE_FUNC
void
OffsetToNdIndex
(
T
offset
,
T
*
index
)
const
{
T
remaining
=
offset
;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for
(
int
i
=
0
;
i
<
N
-
1
;
++
i
)
{
const
T
idx
=
math_helper_
[
i
].
divides
(
remaining
);
index
[
i
]
=
idx
;
remaining
=
remaining
-
math_helper_
[
i
].
mul
(
idx
);
}
index
[
N
-
1
]
=
remaining
;
}
OF_DEVICE_FUNC
void
OffsetToNdIndex
(
T
offset
,
T
*
index
,
int
n
)
const
{
assert
(
n
<=
N
);
T
remaining
=
offset
;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
if
(
i
==
n
-
1
)
{
break
;
}
if
(
i
<
n
-
1
)
{
const
T
idx
=
math_helper_
[
i
].
divides
(
remaining
);
index
[
i
]
=
idx
;
remaining
=
remaining
-
math_helper_
[
i
].
mul
(
idx
);
}
}
index
[
n
-
1
]
=
remaining
;
}
OF_DEVICE_FUNC
constexpr
int
Size
()
const
{
return
N
;
}
private:
OF_DEVICE_FUNC
void
InitFastIntegerMath
(
const
T
*
dims
,
const
int
n
)
{
T
stride_arr
[
N
];
for
(
int
i
=
n
-
1
;
i
<
N
;
++
i
)
{
stride_arr
[
i
]
=
1
;
math_helper_
[
i
]
=
FastIntegerMath
<
T
>
(
1
);
}
for
(
int
i
=
n
-
2
;
i
>=
0
;
--
i
)
{
stride_arr
[
i
]
=
dims
[
i
+
1
]
*
stride_arr
[
i
+
1
];
math_helper_
[
i
]
=
FastIntegerMath
<
T
>
(
stride_arr
[
i
]);
}
}
FastIntegerMath
<
T
>
math_helper_
[
N
];
};
#define UNARY_BROADCAST_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity)
}
// namespace broadcast_elementwise_unary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY
oneflow/core/ep/common/primitive/copy_nd.h
View file @
a715222c
...
...
@@ -206,6 +206,7 @@ void SimplifyCopyNd(size_t num_dims, const int64_t* dst_dims, const int64_t* dst
void
SimplifyThenLaunch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
int64_t
*
dst_pos
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
src_pos
,
const
int64_t
*
extent
)
{
CHECK_GT
(
num_dims
,
0
)
<<
"num_dims must greater than 0"
;
CHECK_LE
(
num_dims
,
kMaxNumDims
);
size_t
simplified_num_dims
=
0
;
int64_t
simplified_dst_dims
[
kMaxNumDims
];
...
...
oneflow/core/ep/common/primitive/elementwise_unary.h
View file @
a715222c
...
...
@@ -25,29 +25,72 @@ namespace primitive {
#define UNARY_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRelu)
#define UNARY_FLOATING_MATH_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kElu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSwish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLeakyRelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kMish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSilu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftPlus) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kThreshold)
#define UNARY_FLOATING_MATH_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kElu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSwish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLeakyRelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kMish) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSilu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftShrink) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftPlus) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kThreshold) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsin) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsinh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtanh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCeil) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExpm1) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFloor) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLgamma) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog2) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog10) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog1p) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocal) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocalNoNan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRint) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRound) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRsqrt) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSigmoid) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSin) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSinh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquare) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrunc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)
#define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs)
#define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot)
#define UNARY_UTILS_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsInf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan)
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsFinite)
}
// namespace primitive
}
// namespace ep
...
...
oneflow/core/ep/common/primitive/matmul.cpp
View file @
a715222c
...
...
@@ -60,12 +60,10 @@ REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MatmulFactory, MatmulFactoryImpl<De
#ifdef WITH_CUDA
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MatmulFactory
,
MatmulFactoryImpl
<
DeviceType
::
kCUDA
>
);
#endif // WITH_CUDA
#ifdef WITH_ROCM
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MatmulFactory
,
MatmulFactoryImpl
<
DeviceType
::
kCUDA
>
);
#endif // WITH_ROCM
}
// namespace
}
// namespace primitive
...
...
oneflow/core/ep/common/primitive/unary_functor.h
View file @
a715222c
...
...
@@ -28,9 +28,16 @@ namespace primitive {
template
<
DeviceType
device
,
UnaryOp
unary_op
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
;
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kIdentity
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
src
);
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kElu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
...
...
@@ -41,7 +48,7 @@ struct UnaryFunctor<device, UnaryOp::kElu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kCelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
()),
inv_alpha
(
1.0
f
/
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
...
...
@@ -54,7 +61,7 @@ struct UnaryFunctor<device, UnaryOp::kCelu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kHardSwish
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
if
(
src
<=
static_cast
<
Src
>
(
-
3
))
{
...
...
@@ -69,7 +76,7 @@ struct UnaryFunctor<device, UnaryOp::kHardSwish, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kHardSigmoid
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
if
(
src
<=
static_cast
<
Src
>
(
-
3
))
{
...
...
@@ -84,7 +91,7 @@ struct UnaryFunctor<device, UnaryOp::kHardSigmoid, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kHardShrink
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
lambd
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
lambd
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
(
src
<=
lambd
&&
src
>=
-
lambd
)
?
static_cast
<
Dst
>
(
0
)
:
static_cast
<
Dst
>
(
src
);
...
...
@@ -95,7 +102,7 @@ struct UnaryFunctor<device, UnaryOp::kHardShrink, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kHardTanh
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
min_val
(
attr0
.
Value
<
double
>
()),
max_val
(
attr1
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
...
...
@@ -114,7 +121,7 @@ struct UnaryFunctor<device, UnaryOp::kHardTanh, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLeakyRelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
float
>
())
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
float
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
((
src
>
static_cast
<
Src
>
(
0.0
))
?
src
:
alpha
*
src
);
...
...
@@ -124,7 +131,7 @@ struct UnaryFunctor<device, UnaryOp::kLeakyRelu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kMish
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
Src
soft_plus_val
=
log
(
static_cast
<
Src
>
(
1
)
+
exp
(
src
));
...
...
@@ -137,7 +144,7 @@ struct UnaryFunctor<device, UnaryOp::kMish, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kRelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
const
Src
zero_val
=
static_cast
<
Src
>
(
0.0
);
...
...
@@ -151,7 +158,7 @@ struct UnaryFunctor<device, UnaryOp::kRelu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kSilu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
src
/
(
static_cast
<
Src
>
(
1
)
+
exp
(
-
src
)));
...
...
@@ -160,7 +167,7 @@ struct UnaryFunctor<device, UnaryOp::kSilu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kSelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
((
src
>
static_cast
<
Src
>
(
0.0
))
...
...
@@ -173,7 +180,7 @@ struct UnaryFunctor<device, UnaryOp::kSelu, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kSoftSign
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
src
/
(
static_cast
<
Src
>
(
1
)
+
abs
(
src
)));
...
...
@@ -182,7 +189,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftSign, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kSoftPlus
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
beta
(
attr0
.
Value
<
double
>
()),
threshold
(
attr1
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
...
...
@@ -196,7 +203,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftPlus, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kSoftShrink
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
if
(
src
<=
alpha
&&
src
>=
-
alpha
)
{
...
...
@@ -212,7 +219,7 @@ struct UnaryFunctor<device, UnaryOp::kSoftShrink, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kThreshold
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
threshold
(
attr0
.
Value
<
double
>
()),
value
(
attr1
.
Value
<
double
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
...
...
@@ -224,7 +231,7 @@ struct UnaryFunctor<device, UnaryOp::kThreshold, Dst, Src> {
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLogicalNot
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
!
src
);
}
};
...
...
@@ -243,6 +250,300 @@ struct UnaryFunctor<device, UnaryOp::kIsNan, bool, Src> {
OF_DEVICE_FUNC
bool
operator
()(
Src
src
)
const
{
return
false
;
}
};
template
<
DeviceType
device
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kIsFinite
,
bool
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
Src
src
)
const
{
return
true
;
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kTrunc
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
);
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
;
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAbs
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
abs
(
src
));
}
};
template
<
DeviceType
device
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAbs
,
uint8_t
,
uint8_t
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
uint8_t
operator
()(
uint8_t
src
)
const
{
return
src
;
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kExp
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
exp
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAcos
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
acos
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAcosh
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
acosh
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAsin
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
asin
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAsinh
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
asinh
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAtan
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
atan
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kAtanh
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
atanh
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kCeil
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
ceil
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kCos
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
cos
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kCosh
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
cosh
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kErf
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
erf
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kErfc
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
erfc
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kExpm1
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
expm1
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kFloor
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
floor
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLgamma
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
lgamma
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLog
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
log
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLog2
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
log2
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLog10
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
log10
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLog1p
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
log1p
(
src
));
}
};
template
<
DeviceType
device
,
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
device
,
UnaryOp
::
kLogSigmoid
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
-
log
(
static_cast
<
Src
>
(
1.0
)
+
exp
(
-
src
)));
}
};
// Arithmetic negation, then narrowing cast to Dst.
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNegative, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    // `auto` keeps the (possibly promoted) type of the negation expression.
    const auto negated = -src;
    return static_cast<Dst>(negated);
  }
};
// Multiplicative inverse: 1 / src. Division by zero follows the type's own
// semantics (inf/NaN for floating point).
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kReciprocal, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const Src one = static_cast<Src>(1.0);
    return static_cast<Dst>(one / src);
  }
};
// Reciprocal that maps 0 -> 0 instead of producing inf.
// Fix: the previous guard was `abs(src) <= 0`. Unqualified `abs` can bind to
// the C integer `::abs`, truncating fractional inputs to 0 and wrongly
// returning 0 for any |src| < 1. Since abs(x) <= 0 is equivalent to x == 0
// for non-NaN x, test equality with zero directly (NaN still propagates
// through the division branch, as before).
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kReciprocalNoNan, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    if (src == static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }
    return static_cast<Dst>(static_cast<Src>(1.0) / src);
  }
};
// Rounding / inverse-sqrt functors.
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRint, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  // Round to nearest integer using the current FP rounding mode.
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto rounded = rint(src);
    return static_cast<Dst>(rounded);
  }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRound, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  // nearbyint follows the current rounding mode (round-half-to-even by
  // default), matching banker's rounding rather than round-half-away.
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto rounded = nearbyint(src);
    return static_cast<Dst>(rounded);
  }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRsqrt, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  // Reciprocal square root; `rsqrt` is supplied by the device math layer.
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto result = rsqrt(src);
    return static_cast<Dst>(result);
  }
};
// Logistic sigmoid: 1 / (1 + exp(-src)).
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSigmoid, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const Src one = static_cast<Src>(1.0);
    return static_cast<Dst>(one / (one + exp(-src)));
  }
};
// Sign function: +1 for positive, -1 for negative, 0 otherwise
// (NaN falls into the final branch and yields 0).
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSign, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const Src zero = static_cast<Src>(0.0);
    if (zero < src) { return static_cast<Dst>(1.0); }
    if (src < zero) { return static_cast<Dst>(-1.0); }
    return static_cast<Dst>(0.0);
  }
};
// Trigonometric / power functors.
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSin, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto result = sin(src);
    return static_cast<Dst>(result);
  }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSinh, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto result = sinh(src);
    return static_cast<Dst>(result);
  }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSqrt, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto result = sqrt(src);
    return static_cast<Dst>(result);
  }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kSquare, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  // src * src — cheaper and exact compared to pow(src, 2).
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto squared = src * src;
    return static_cast<Dst>(squared);
  }
};
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kTan, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const auto result = tan(src);
    return static_cast<Dst>(result);
  }
};

// 1 if src != 0, else 0 (cast from bool to Dst).
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNotEqualZero, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    const bool nonzero = (src != static_cast<Src>(0.0));
    return static_cast<Dst>(nonzero);
  }
};
// Replace NaN with zero; pass every other value through.
// Fix: the previous ternary returned `src` unconverted, forcing the
// conditional expression to compute a Src/Dst common type before the final
// (implicit) conversion. Cast each branch to Dst explicitly so the
// conversion path is the same as every other functor in this file.
template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kNanAssign, Dst, Src> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src) const {
    if (std::isnan(src)) { return static_cast<Dst>(0.0); }
    return static_cast<Dst>(src);
  }
};
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
...
...
oneflow/core/ep/common/primitive/util.h
View file @
a715222c
...
...
@@ -37,6 +37,71 @@ bool IsPackSizeSupported(const size_t pack_size, size_t num_dims, const int64_t*
&&
(
reinterpret_cast
<
std
::
uintptr_t
>
(
ptr
)
%
(
pack_size
*
sizeof
(
T
))
==
0
);
}
// Guard for in-place launches: when src and dst alias the same buffer, the
// two views must agree element-for-element in their dims (or strides —
// callers pass either), otherwise the kernel would read data it has already
// overwritten. CHECK_EQ aborts on mismatch.
// Fix: loop index is now size_t, removing the signed/unsigned comparison
// against `num_dims`.
inline void CheckInplace(size_t num_dims, const int64_t* src_dims_or_strides, const void* src,
                         const int64_t* dst_dims_or_strides, const void* dst) {
  if (src != dst) { return; }
  for (size_t i = 0; i < num_dims; ++i) {
    CHECK_EQ(src_dims_or_strides[i], dst_dims_or_strides[i]);
  }
}
// Canonicalizes a (src -> dst) broadcast so downstream kernels see the
// fewest possible dimensions:
//   1. right-align src against dst, padding missing leading dims with
//      size 1 / stride 0 ("dimension completion");
//   2. permute all dims so dst strides are in descending order
//      ("dimension permutation");
//   3. fuse adjacent dims that are jointly contiguous in both src and dst
//      and share the same broadcast-ness ("dimension merge").
// Outputs are written through the five `simplified_*` pointers; each output
// array must have room for `max_num_dims` entries.
// NOTE(review): the merge test compares against new_*_strides[i - 1], the
// neighbor in the permuted order, even when that neighbor was a size-1 dim
// that was skipped — presumably size-1 dims keep consistent strides here;
// verify against callers.
template<size_t max_num_dims>
inline void SimplifyBroadcastDims(size_t num_src_dims, const int64_t* src_dims,
                                  const int64_t* src_strides, size_t num_dst_dims,
                                  const int64_t* dst_dims, const int64_t* dst_strides,
                                  size_t* simplified_num_dims, int64_t* simplified_src_dims,
                                  int64_t* simplified_src_strides, int64_t* simplified_dst_dims,
                                  int64_t* simplified_dst_strides) {
  *simplified_num_dims = 0;
  // (stride, original axis index) pairs, sorted by stride descending so the
  // permuted dst is laid out outer-to-inner.
  std::pair<int64_t, size_t> sorted_dst_strides[max_num_dims];
  int64_t new_dst_dims[max_num_dims];
  int64_t new_src_dims[max_num_dims];
  int64_t new_dst_strides[max_num_dims];
  int64_t new_src_strides[max_num_dims];
  for (size_t i = 0; i < num_dst_dims; i++) { sorted_dst_strides[i] = {dst_strides[i], i}; }
  std::sort(sorted_dst_strides, sorted_dst_strides + num_dst_dims,
            [](auto pair1, auto pair2) { return pair1.first > pair2.first; });
  const int64_t num_src_padding_dims = num_dst_dims - num_src_dims;
  // dimension completion
  int64_t expanded_src_dims[max_num_dims];
  int64_t expanded_src_strides[max_num_dims];
  for (int64_t i = num_dst_dims - 1; i >= 0; i--) {
    // Padded leading dims get size 1 and stride 0 (fully broadcast).
    expanded_src_dims[i] = i < num_src_padding_dims ? 1 : src_dims[i - num_src_padding_dims];
    expanded_src_strides[i] = i < num_src_padding_dims ? 0 : src_strides[i - num_src_padding_dims];
  }
  // dimension permutation
  for (int64_t i = num_dst_dims - 1; i >= 0; i--) {
    size_t idx = sorted_dst_strides[i].second;
    new_dst_dims[i] = dst_dims[idx];
    new_dst_strides[i] = dst_strides[idx];
    new_src_dims[i] = expanded_src_dims[idx];
    new_src_strides[i] = expanded_src_strides[idx];
  }
  // dimension merge
  bool prev_broadcast_src = false;
  for (int64_t i = 0; i < num_dst_dims; ++i) {
    const bool broadcast_src = (new_src_dims[i] == 1);
    if (new_dst_dims[i] == 1) {
      // Size-1 dst dims carry no iteration work; drop them.
      continue;
    } else if (*simplified_num_dims != 0 && prev_broadcast_src == broadcast_src
               && (new_src_strides[i - 1] == new_src_strides[i] * new_src_dims[i])
               && (new_dst_strides[i - 1] == new_dst_strides[i] * new_dst_dims[i])) {
      // Contiguous with the previously emitted dim on both sides and same
      // broadcast-ness: fold this dim into it.
      simplified_src_dims[*simplified_num_dims - 1] *= new_src_dims[i];
      simplified_dst_dims[*simplified_num_dims - 1] *= new_dst_dims[i];
      simplified_src_strides[*simplified_num_dims - 1] = new_src_strides[i];
      simplified_dst_strides[*simplified_num_dims - 1] = new_dst_strides[i];
    } else {
      // Start a new output dim.
      simplified_src_dims[*simplified_num_dims] = new_src_dims[i];
      simplified_dst_dims[*simplified_num_dims] = new_dst_dims[i];
      simplified_src_strides[*simplified_num_dims] = new_src_strides[i];
      simplified_dst_strides[*simplified_num_dims] = new_dst_strides[i];
      *simplified_num_dims += 1;
      prev_broadcast_src = broadcast_src;
    }
  }
}
inline
void
SimplifyBroadcastDims
(
size_t
num_a_dims
,
const
int64_t
*
a_dims
,
size_t
num_b_dims
,
const
int64_t
*
b_dims
,
size_t
num_c_dims
,
const
int64_t
*
c_dims
,
size_t
*
simplified_num_dims
,
int64_t
*
simplified_broadcast_dims
,
...
...
oneflow/core/ep/cpu/cpu_device.cpp
View file @
a715222c
...
...
@@ -42,15 +42,13 @@ Maybe<void> CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_
this
->
device_manager
()
->
registry
()
->
GetDevice
(
options
.
GetPinnedDeviceType
(),
// NOLINT
options
.
GetPinnedDeviceIndex
());
// NOLINT
CHECK_OR_RETURN
(
device
);
return
device
->
AllocPinned
(
options
,
ptr
,
size
);
JUST
(
device
->
AllocPinned
(
options
,
ptr
,
size
)
)
;
}
else
{
*
ptr
=
aligned_alloc
(
kMaxAlignmentRequirement
,
size
);
if
(
*
ptr
==
nullptr
)
{
return
Error
::
RuntimeError
()
<<
"allocate failed"
;
}
else
{
return
Maybe
<
void
>::
Ok
();
}
*
ptr
=
aligned_alloc
(
kMaxAlignmentRequirement
,
RoundUp
(
size
,
kMaxAlignmentRequirement
));
if
(
*
ptr
==
nullptr
)
{
return
Error
::
RuntimeError
()
<<
"allocate failed"
;
}
}
memset
(
*
ptr
,
0
,
size
);
return
Maybe
<
void
>::
Ok
();
}
void
CpuDevice
::
Free
(
const
AllocationOptions
&
options
,
void
*
ptr
)
{
...
...
oneflow/core/ep/cpu/primitive/binary_functor.h
View file @
a715222c
...
...
@@ -29,23 +29,224 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, Src, Dst> {
};
template
<
>
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kPow
,
bool
,
bool
>
{
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kPow
,
float16
,
float16
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
bool
src0
,
bool
src1
)
const
{
return
static_cast
<
bool
>
(
std
::
pow
(
static_cast
<
double
>
(
src0
),
static_cast
<
double
>
(
src1
)));
OF_DEVICE_FUNC
float16
operator
()(
float16
src0
,
float16
src1
)
const
{
return
static_cast
<
float16
>
(
std
::
pow
(
static_cast
<
float
>
(
src0
),
static_cast
<
float
>
(
src1
)));
}
};
template
<
>
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kPow
,
float16
,
float16
>
{
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kFmod
,
float
,
float
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
float
operator
()(
float
src0
,
float
src1
)
const
{
return
std
::
fmod
(
src0
,
src1
);
}
};
// Floating remainder for double: result has the sign of src0 (C semantics).
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC double operator()(double src0, double src1) const {
    const double remainder = std::fmod(src0, src1);
    return remainder;
  }
};
// Floating remainder for float16, computed in float (float16 has no native
// std::fmod overload; the extra precision is harmless before narrowing).
// Fix: the span contained a stale `std::pow` return statement ahead of the
// intended `std::fmod` return (diff residue); only the fmod result is kept.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float16, float16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
    return static_cast<float16>(std::fmod(static_cast<float>(src0), static_cast<float>(src1)));
  }
};
// Floating remainder for bfloat16.
// Consistency fix: compute in float and narrow back, matching the float16
// specialization, instead of relying on implicit bfloat16 promotions inside
// std::fmod (which would otherwise go through whatever conversion operators
// bfloat16 defines, possibly selecting the double overload).
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, bfloat16, bfloat16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
    return static_cast<bfloat16>(std::fmod(static_cast<float>(src0), static_cast<float>(src1)));
  }
};
// Floor division: floor(src0 / src1). Half-precision types compute the
// quotient in float before flooring.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float operator()(float src0, float src1) const {
    const float quotient = src0 / src1;
    return std::floor(quotient);
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC double operator()(double src0, double src1) const {
    const double quotient = src0 / src1;
    return std::floor(quotient);
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float16, float16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
    const float quotient = static_cast<float>(src0) / static_cast<float>(src1);
    return static_cast<float16>(std::floor(quotient));
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, bfloat16, bfloat16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
    const auto quotient = src0 / src1;
    return std::floor(quotient);
  }
};
// Truncating division: trunc(src0 / src1), i.e. rounding toward zero.
// Half-precision types compute the quotient in float before truncating.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float operator()(float src0, float src1) const {
    const float quotient = src0 / src1;
    return std::trunc(quotient);
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC double operator()(double src0, double src1) const {
    const double quotient = src0 / src1;
    return std::trunc(quotient);
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float16, float16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
    const float quotient = static_cast<float>(src0) / static_cast<float>(src1);
    return static_cast<float16>(std::trunc(quotient));
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, bfloat16, bfloat16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
    const auto quotient = src0 / src1;
    return std::trunc(quotient);
  }
};
// Floor modulo (Python-style): the result takes the sign of the divisor.
// Start from the C truncated remainder and shift by src1 when the signs of
// remainder and divisor disagree.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float operator()(float src0, float src1) const {
    const float zero = static_cast<float>(0);
    float trunc_mod = std::fmod(src0, src1);
    const bool needs_shift = (trunc_mod != zero) && ((src1 < zero) != (trunc_mod < zero));
    if (needs_shift) { return trunc_mod + src1; }
    return trunc_mod;
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC double operator()(double src0, double src1) const {
    const double zero = static_cast<double>(0);
    double trunc_mod = std::fmod(src0, src1);
    const bool needs_shift = (trunc_mod != zero) && ((src1 < zero) != (trunc_mod < zero));
    if (needs_shift) { return trunc_mod + src1; }
    return trunc_mod;
  }
};
// Half-precision floor modulo: widen to float, delegate to the float
// functor, narrow the result back.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float16, float16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;
  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
    const float lhs = static_cast<float>(src0);
    const float rhs = static_cast<float>(src1);
    return static_cast<float16>(float_functor(lhs, rhs));
  }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, bfloat16, bfloat16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;
  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {
    const float lhs = static_cast<float>(src0);
    const float rhs = static_cast<float>(src1);
    return static_cast<bfloat16>(float_functor(lhs, rhs));
  }
};
// d/dx (x^c) * dy = c * x^(c-1) * dy, evaluated in float for float16 inputs.
// attr0 carries the scalar exponent c; attr1 is unused.
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarBasePowerGrad, float16, float16> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<float>()) {}
  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
    const float base = static_cast<float>(src0);
    const float dy = static_cast<float>(src1);
    const float grad = scalar_operand * (std::pow(base, scalar_operand - static_cast<float>(1))) * dy;
    return static_cast<float16>(grad);
  }
  float scalar_operand;
};
// Integer-typed scalar-exponent power gradients: widen both operands to
// float, delegate to the float functor, and narrow the result to Dst.
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {
    const float lhs = static_cast<float>(src0);
    const float rhs = static_cast<float>(src1);
    return static_cast<Dst>(float_functor(lhs, rhs));
  }
};

template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int8_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const {
    const float lhs = static_cast<float>(src0);
    const float rhs = static_cast<float>(src1);
    return static_cast<Dst>(float_functor(lhs, rhs));
  }
};

template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, uint8_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const {
    const float lhs = static_cast<float>(src0);
    const float rhs = static_cast<float>(src1);
    return static_cast<Dst>(float_functor(lhs, rhs));
  }
};
// int64_t scalar-exponent power gradient, delegated through float.
// Fix: this is the int64_t specialization, but the previous operator() was
// declared with `int` parameters, so 64-bit inputs were silently narrowed
// before the computation. The parameters now match the specialization's
// Src type.
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int64_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(int64_t src0, int64_t src1) const {
    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
  }
};
// d/dx (c^x) * dy = c^x * ln(c) * dy, evaluated in float for float16 inputs.
// attr0 carries the scalar base c; attr1 is unused.
template<typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float16, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<float>()) {}
  OF_DEVICE_FUNC Dst operator()(float16 src0, float16 src1) const {
    const float x = static_cast<float>(src0);
    const float dy = static_cast<float>(src1);
    return static_cast<Dst>(std::pow(scalar_operand, x) * std::log(scalar_operand) * dy);
  }
  float scalar_operand;
};
template
<
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kGeluBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
...
...
@@ -59,6 +260,39 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kGeluBackwardWithDyX, Src, Dst>
Src
coef
=
std
::
sqrt
(
2.0
/
std
::
acos
(
-
1.0
));
};
// Backward of the tanh-approximated ("fast") GELU:
//   gelu(x) ≈ 0.5 * x * (1 + tanh(alpha * (x + beta * x^3)))
// with alpha = sqrt(2/pi) and beta ≈ 0.044715, so
//   dgelu/dx = 0.5*(1 + tanh(u)) + x*0.5*(1 - tanh(u)^2)*alpha*(1 + 3*beta*x^2)
// which is regrouped below as half + half*tanh_out + dtanh*(1 - tanh_out^2)
// with dtanh = alpha*(0.5*x + 1.5*beta*x^3).
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFastGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu
    const Src one = static_cast<Src>(1);
    const Src half = static_cast<Src>(0.5);
    const Src pow3 = x * x * x;
    const Src tanh_out = std::tanh(alpha * (x + beta * pow3));
    // dtanh folds the x*0.5*u' factor of the chain rule (u' = alpha*(1+3*beta*x^2)).
    const Src dtanh = alpha * (half * x + beta * static_cast<Src>(1.5) * pow3);
    return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out));
  }

 private:
  // alpha = sqrt(2/pi); beta matches the forward pass's coefficient.
  static constexpr Src alpha = static_cast<Src>(0.7978845608028654);
  static constexpr Src beta = static_cast<Src>(0.044714998453855515);
};
// Backward of quick-GELU (x * sigmoid(1.702 * x)):
//   d/dx = sigmoid(a*x) + a*x*sigmoid(a*x)*(1 - sigmoid(a*x)), a = 1.702.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kQuickGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    const Src sig = one / (one + exp(-x * alpha));
    return dy * (sig + alpha * x * (sig * (one - sig)));
  }

 private:
  static constexpr Src alpha = static_cast<Src>(1.702);
};
template
<
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
DeviceType
::
kCPU
,
BinaryOp
::
kTanhBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
...
...
@@ -69,6 +303,82 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyX, Src, Dst>
}
};
// d/dx acos(x) = -1 / sqrt(1 - x^2)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcosBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * -(one / sqrt(one - x * x));
  }
};

// d/dx acosh(x) = 1 / sqrt(x^2 - 1)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcoshBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy / sqrt(x * x - one);
  }
};
// d/dx asin(x) = 1 / sqrt(1 - x^2)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * (one / sqrt(one - x * x));
  }
};

// d/dx asinh(x) = 1 / sqrt(1 + x^2)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * (one / sqrt(one + x * x));
  }
};
// d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src two = static_cast<Src>(2.0);
    const Src inv_sqrt_pi = static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI));
    return dy * two * inv_sqrt_pi * exp(-x * x);
  }
};

// d/dx erfc(x) = -(2 / sqrt(pi)) * exp(-x^2)
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfcBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src neg_two = static_cast<Src>(-2.0);
    const Src inv_sqrt_pi = static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI));
    return dy * neg_two * inv_sqrt_pi * exp(-x * x);
  }
};
// Specializes BinaryFunctor for integer-like element types (bool, char) by
// reusing the int/int functor: operands are widened to int, the op is
// evaluated there, and the result is narrowed back to `type`.
#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type)                                          \
  template<>                                                                                 \
  struct BinaryFunctor<DeviceType::kCPU, op, type, type> {                                   \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {}  \
                                                                                             \
    BinaryFunctor<DeviceType::kCPU, op, int, int> int_functor;                               \
    OF_DEVICE_FUNC type operator()(type src0, type src1) const {                             \
      return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \
    }                                                                                        \
  };
// bool specializations routed through the int/int functor.
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool);
// char specializations routed through the int/int functor.
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char);
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
...
...
oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
View file @
a715222c
...
...
@@ -45,6 +45,11 @@ float16 GetValue<float16>(Scalar value) {
return
static_cast
<
float16
>
(
GetValue
<
float
>
(
value
));
}
// Scalar -> bfloat16 extraction: read the value as float, then narrow.
template<>
bfloat16 GetValue<bfloat16>(Scalar value) {
  const float as_float = GetValue<float>(value);
  return static_cast<bfloat16>(as_float);
}
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryLhsScalarFunctor
{
BinaryLhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
...
...
@@ -247,8 +252,8 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di
SimplifyBroadcastDims
<
kMaxNumDims
>
(
num_src0_dims
,
src0_dims
,
num_src1_dims
,
src1_dims
,
&
simplified_num_dims
,
simplified_src0_dims
,
simplified_src1_dims
,
simplified_dst_dims
);
CheckInplace
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_
src1
_dims
,
src1
,
simplified_dst_dims
,
dst
);
CheckInplace
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_
dst
_dims
,
dst
);
CheckInplace
(
simplified_num_dims
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
);
if
(
IsDimsEquals
(
simplified_num_dims
,
simplified_src0_dims
,
simplified_num_dims
,
simplified_src1_dims
))
{
LaunchElementwise
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
simplified_num_dims
,
simplified_src0_dims
,
...
...
@@ -260,16 +265,20 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di
}
else
if
(
simplified_num_dims
==
1
&&
simplified_src1_dims
[
0
]
==
1
)
{
LaunchBinaryRhsScalar
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
*
src1
,
simplified_src0_dims
[
0
],
src0
,
dst
,
attr0
,
attr1
);
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src0_dims
[
0
]
==
1
)
{
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src0_dims
[
0
]
==
1
&&
simplified_src0_dims
[
1
]
==
simplified_src1_dims
[
1
])
{
LaunchRowWithMatrix
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
,
attr0
,
attr1
);
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src1_dims
[
0
]
==
1
)
{
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src1_dims
[
0
]
==
1
&&
simplified_src0_dims
[
1
]
==
simplified_src1_dims
[
1
])
{
LaunchMatrixWithRow
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
,
attr0
,
attr1
);
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src0_dims
[
1
]
==
1
)
{
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src0_dims
[
1
]
==
1
&&
simplified_src0_dims
[
0
]
==
simplified_src1_dims
[
0
])
{
LaunchColWithMatrix
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
,
attr0
,
attr1
);
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src1_dims
[
1
]
==
1
)
{
}
else
if
(
simplified_num_dims
==
2
&&
simplified_src1_dims
[
1
]
==
1
&&
simplified_src0_dims
[
0
]
==
simplified_src1_dims
[
0
])
{
LaunchMatrixWithCol
<
binary_op
,
Src
,
Dst
>
(
cpu_stream
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
,
attr0
,
attr1
);
}
else
{
...
...
@@ -405,8 +414,8 @@ class OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
src1_dims
,
dst_dims
);
}
CheckInplace
(
num_dims
,
src_0_dims
.
data
(),
onednn_src0
,
src_1
_dims
.
data
(),
onednn_src1
,
dst_dims
.
data
(),
dst
);
CheckInplace
(
num_dims
,
src_0_dims
.
data
(),
onednn_src0
,
dst
_dims
.
data
(),
dst
);
CheckInplace
(
num_dims
,
src_1_dims
.
data
(),
onednn_src1
,
dst_dims
.
data
(),
dst
);
auto
src_0_md
=
dnnl
::
memory
::
desc
(
src_0_dims
,
src_onednn
,
...
...
@@ -564,7 +573,11 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ
)};
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_BACKWARD_OP_SEQ
,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
...
...
oneflow/core/ep/cpu/primitive/broadcast_elementwise_unary.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/cpu/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/ep/cpu/cpu_stream.h"
#include "oneflow/core/ep/cpu/cpu_device.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_unary
{
namespace
{
// Returns true when (dims, strides) describe a dense row-major layout:
// the innermost stride is 1 and each outer stride equals the product of all
// inner dims (equivalently strides[i] == dims[i+1] * strides[i+1]).
// An empty shape (num_dims == 0) is trivially contiguous.
// Fix: the previous loop compared a signed `int` index against the unsigned
// `num_dims`; the rewrite walks with unsigned indices only.
bool IsContiguous(size_t num_dims, const int64_t* dims, const int64_t* strides) {
  int64_t expected_stride = 1;
  for (size_t j = num_dims; j > 0; j--) {
    const size_t i = j - 1;  // walk innermost -> outermost without going negative
    if (strides[i] != expected_stride) { return false; }
    expected_stride *= dims[i];
  }
  return true;
}
// Fully-broadcast case: the source is a single element, so apply the functor
// once and replicate the converted value across the (strided) destination.
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchScalarFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t stride,
                      Scalar attr0, Scalar attr1) {
  UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst> functor(attr0, attr1);
  const Dst filled = functor(*src);
  stream->ParallelFor(0, count, [dst, stride, filled](int64_t begin, int64_t end) {
    for (int64_t idx = begin; idx < end; ++idx) { dst[idx * stride] = filled; }
  });
}
// 1-D (possibly strided) element-wise apply: dst[i*dst_stride] =
// functor(src[i*src_stride]) for i in [0, count), parallelized in chunks.
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchTensorFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t dst_stride,
                      size_t src_stride, Scalar attr0, Scalar attr1) {
  UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst> functor(attr0, attr1);
  stream->ParallelFor(
      0, count, [functor, src, dst, src_stride, dst_stride](int64_t begin, int64_t end) {
        for (int64_t idx = begin; idx < end; ++idx) {
          dst[idx * dst_stride] = functor(src[idx * src_stride]);
        }
      });
}
// General N-d broadcast with arbitrary strides. Iterates over destination
// linear offsets, maps each back to an nd-index, collapses broadcast axes
// (src dim == 1) to index 0, and reads the source through its strides.
// When the destination is contiguous the linear offset doubles as the
// storage offset, skipping one index->offset conversion per element.
template<UnaryOp unary_op, typename Src, typename Dst>
void LaunchGeneral(CpuStream* stream, Dst* dst, const Src* src, size_t num_dims,
                   const int64_t* dst_dims, const int64_t* src_dims, const int64_t* dst_stride,
                   const int64_t* src_stride, Scalar attr0, Scalar attr1) {
  const bool contiguous_output = IsContiguous(num_dims, dst_dims, dst_stride);
  const int64_t elem_cnt = GetElementCount(num_dims, dst_dims);
  // NOTE(review): was UnaryFunctor<DeviceType::kCPU, unary_op, Src, Dst>, but the
  // UnaryFunctor template is declared <Device, Op, Dst, Src>. Fixed the argument
  // order; no behavior change for same-dtype instantiations.
  auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Dst, Src>(attr0, attr1);
  stream->ParallelFor(
      0, elem_cnt,
      [functor, src, dst, num_dims, src_dims, dst_dims, src_stride, dst_stride,
       contiguous_output](int64_t begin, int64_t end) {
        // Helpers and scratch indices are per-task so each worker owns its own.
        auto src_index_to_offset_helper =
            IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(src_stride, num_dims);
        auto dst_offset_to_index_helper =
            OffsetToIndexWithStrideCalculator<int64_t, kMaxNumDims>(dst_dims, num_dims);
        auto dst_index_to_offset_helper =
            IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(dst_stride, num_dims);
        int64_t src_index[kMaxNumDims];
        int64_t dst_index[kMaxNumDims];
        for (int64_t offset = begin; offset < end; offset++) {
          dst_offset_to_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
          for (int i = 0; i < kMaxNumDims; i++) {
            if (i < num_dims) {
              // Broadcast axes (source extent 1) always read source index 0.
              src_index[i] = (src_dims[i] != 1) ? dst_index[i] : 0;
            } else {
              src_index[i] = 0;
            }
          }
          const int64_t src_offset =
              src_index_to_offset_helper.NdIndexToOffset(src_index, num_dims);
          if (!contiguous_output) {
            const int64_t dst_offset =
                dst_index_to_offset_helper.NdIndexToOffset(dst_index, num_dims);
            dst[dst_offset] = functor(src[src_offset]);
          } else {
            dst[offset] = functor(src[src_offset]);
          }
        }
      });
}
// CPU implementation of the BroadcastElementwiseUnary primitive. Simplifies
// the broadcast shape, then dispatches to the cheapest launcher: scalar fill
// (single source element), 1-D strided apply, or the general N-d path.
template<UnaryOp unary_op, typename Src, typename Dst>
class BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl);
  BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
  ~BroadcastElementwiseUnaryImpl() override = default;

  // Dims-only overload: derive dense row-major strides for both tensors,
  // then forward to the strided overload below.
  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,
              size_t num_dst_dims, const int64_t* dst_dims, void* dst) override {
    CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0";
    CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0";
    int64_t src_strides[kMaxNumDims];
    int64_t dst_strides[kMaxNumDims];
    // init stride
    // The innermost stride (and any unused trailing slot) is 1; each outer
    // stride is the product of the dim and stride directly inside it.
    for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; }
    for (int i = num_src_dims - 2; i >= 0; --i) {
      src_strides[i] = src_dims[i + 1] * src_strides[i + 1];
    }
    for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; }
    for (int i = num_dst_dims - 2; i >= 0; --i) {
      dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1];
    }
    Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims,
           dst_strides, dst);
  }

  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,
              const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims,
              const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override {
    CHECK_GT(num_src_dims, 0) << "num_src_dims must greater than 0";
    CHECK_GT(num_dst_dims, 0) << "num_dst_dims must greater than 0";
    auto* cpu_stream = stream->As<CpuStream>();
    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);
    const Src* src = reinterpret_cast<const Src*>(src_ptr);
    // Collapse adjacent dimensions that broadcast/stride identically.
    size_t simplified_num_dims = 0;
    int64_t simplified_src_dims[kMaxNumDims];
    int64_t simplified_dst_dims[kMaxNumDims];
    int64_t simplified_src_strides[kMaxNumDims];
    int64_t simplified_dst_strides[kMaxNumDims];
    SimplifyBroadcastDims<kMaxNumDims>(num_src_dims, src_dims, src_strides, num_dst_dims,
                                       dst_dims, dst_strides, &simplified_num_dims,
                                       simplified_src_dims, simplified_src_strides,
                                       simplified_dst_dims, simplified_dst_strides);
    // Validate in-place use against both the simplified dims and strides.
    CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst);
    CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst);
    if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) {
      // Single source element broadcast to the whole destination.
      const int64_t elem_cnt = simplified_dst_dims[0];
      const int64_t dst_stride = simplified_dst_strides[0];
      LaunchScalarFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride, attr0,
                                           attr1);
    } else if (simplified_num_dims == 1) {
      // Shapes collapsed to a single axis: plain 1-D strided apply.
      const int64_t elem_cnt = simplified_src_dims[0];
      const int64_t src_stride = simplified_src_strides[0];
      const int64_t dst_stride = simplified_dst_strides[0];
      LaunchTensorFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride,
                                           src_stride, attr0, attr1);
    } else {
      LaunchGeneral<unary_op, Src, Dst>(cpu_stream, dst, src, simplified_num_dims,
                                        simplified_dst_dims, simplified_src_dims,
                                        simplified_dst_strides, simplified_src_strides, attr0,
                                        attr1);
    }
  }

 protected:
  Scalar attr0, attr1;
};
// Factory helper: wraps a concrete impl behind the type-erased primitive
// interface so the registry map below can store a uniform factory signature.
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseUnary> NewBroadcastElementwiseUnary(Scalar attr0,
                                                                        Scalar attr1) {
  return std::unique_ptr<BroadcastElementwiseUnary>(
      new BroadcastElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}
// Factory that resolves (op, src dtype, dst dtype) to a concrete CPU
// broadcast-unary primitive via a lazily-built static lookup table.
class BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl);
  BroadcastElementwiseUnaryFactoryImpl() = default;
  ~BroadcastElementwiseUnaryFactoryImpl() override = default;

  // Convenience overload: no attributes.
  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type,
                                                 DataType dst_type,
                                                 size_t max_num_dims) override {
    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
  }

  // Convenience overload: one attribute.
  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type,
                                                 DataType dst_type, size_t max_num_dims,
                                                 Scalar attr0) override {
    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
  }

  // Returns nullptr when the rank exceeds kMaxNumDims or the (op, dtypes)
  // combination is not registered.
  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
                                                 DataType dst_type, size_t max_num_dims,
                                                 Scalar attr0, Scalar attr1) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }
#define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
   NewBroadcastElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), \
                                OF_PP_PAIR_FIRST(dtype_pair)>},
    static const std::map<
        std::tuple<UnaryOp, DataType, DataType>,
        std::function<std::unique_ptr<BroadcastElementwiseUnary>(Scalar, Scalar)>>
        new_broadcast_elementwise_unary_handle{
            // For All Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
                MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY,
                UNARY_BROADCAST_OP_SEQ, CPU_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY
    const auto iter =
        new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type));
    if (iter != new_broadcast_elementwise_unary_handle.end()) {
      return iter->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCPU
,
BroadcastElementwiseUnaryFactory
,
BroadcastElementwiseUnaryFactoryImpl
);
}
// namespace
}
// namespace broadcast_elementwise_unary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/cpu/primitive/cast.cpp
View file @
a715222c
...
...
@@ -23,10 +23,28 @@ namespace primitive {
namespace
{
template
<
typename
From
,
typename
To
>
void
CastCpu
(
const
From
*
from
,
To
*
to
,
size_t
count
)
{
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
to
[
i
]
=
static_cast
<
To
>
(
from
[
i
]);
}
}
template
<
typename
From
,
typename
To
,
typename
=
void
>
struct
CpuCastFunctor
{
static
void
Call
(
const
From
*
from
,
To
*
to
,
size_t
count
)
{
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
to
[
i
]
=
static_cast
<
To
>
(
from
[
i
]);
}
}
};
template
<
typename
To
>
struct
CpuCastFunctor
<
bfloat16
,
To
,
typename
std
::
enable_if
<!
(
std
::
is_same
<
To
,
bfloat16
>::
value
)
>::
type
>
{
static
void
Call
(
const
bfloat16
*
from
,
To
*
to
,
size_t
count
)
{
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
to
[
i
]
=
static_cast
<
To
>
(
static_cast
<
float
>
(
from
[
i
]));
}
}
};
template
<
typename
From
>
struct
CpuCastFunctor
<
From
,
bfloat16
,
typename
std
::
enable_if
<!
(
std
::
is_same
<
From
,
bfloat16
>::
value
)
>::
type
>
{
static
void
Call
(
const
From
*
from
,
bfloat16
*
to
,
size_t
count
)
{
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
to
[
i
]
=
bfloat16
(
static_cast
<
float
>
(
from
[
i
]));
}
}
};
template
<
typename
From
,
typename
To
>
class
CastImpl
:
public
Cast
{
...
...
@@ -36,7 +54,8 @@ class CastImpl : public Cast {
~
CastImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
from
,
void
*
to
,
size_t
count
)
override
{
CastCpu
(
reinterpret_cast
<
const
From
*>
(
from
),
reinterpret_cast
<
To
*>
(
to
),
count
);
CpuCastFunctor
<
From
,
To
>::
Call
(
reinterpret_cast
<
const
From
*>
(
from
),
reinterpret_cast
<
To
*>
(
to
),
count
);
}
};
...
...
@@ -56,7 +75,8 @@ std::unique_ptr<Cast> NewCast() {
CPU_PRIMITIVE_UINT64_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \
CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ
class
CastFactoryImpl
:
public
CastFactory
{
public:
...
...
oneflow/core/ep/cpu/primitive/constant_pad.cpp
View file @
a715222c
...
...
@@ -56,6 +56,11 @@ float16 GetValue<float16>(Scalar value) {
return
static_cast
<
float16
>
(
GetValue
<
float
>
(
value
));
}
template
<
>
bfloat16
GetValue
<
bfloat16
>
(
Scalar
value
)
{
return
static_cast
<
bfloat16
>
(
GetValue
<
float
>
(
value
));
}
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
)
{
ConstantPadKernel
<
num_dims
,
IndexType
,
StorageType
>
(
params
,
packed_pad_val
);
...
...
@@ -163,6 +168,7 @@ template<typename T>
void
SimplifyThenLaunch
(
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
T
pad_val
,
void
*
dst
)
{
CHECK_GT
(
num_dims
,
0
)
<<
"num_dims must greater than 0"
;
CHECK_LE
(
num_dims
,
kMaxNumDims
);
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_src_dims
[
kMaxNumDims
];
...
...
oneflow/core/ep/cpu/primitive/elementwise_unary.cpp
View file @
a715222c
...
...
@@ -88,9 +88,13 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_MATH_OP_SEQ
,
CPU_PRIMITIVE_NATIVE_TYPE_SEQ
)
// For Float Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_FLOATING_MATH_OP_SEQ
,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ
CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ
)
// For Int Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
UNARY_FLOATING_MATH_OP_SEQ
,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ
)
UNARY_INT_MATH_OP_SEQ
,
CPU_PRIMITIVE_INT_TYPE_SEQ
)
// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
,
...
...
oneflow/core/ep/cpu/primitive/fill.cpp
View file @
a715222c
...
...
@@ -34,6 +34,11 @@ float16 GetValue<float16>(Scalar value) {
return
static_cast
<
float16
>
(
GetValue
<
float
>
(
value
));
}
template
<
>
bfloat16
GetValue
<
bfloat16
>
(
Scalar
value
)
{
return
static_cast
<
bfloat16
>
(
GetValue
<
float
>
(
value
));
}
template
<
typename
T
>
class
FillImpl
:
public
Fill
{
public:
...
...
oneflow/core/ep/cpu/primitive/tensor_fill.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/tensor_fill.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// CPU TensorFill primitive: replicates the single element at `src` into
// `count` contiguous elements of `dst`.
template<typename T>
class TensorFillImpl : public TensorFill {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl);
  TensorFillImpl() = default;
  ~TensorFillImpl() override = default;

  void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
    const T fill_value = *reinterpret_cast<const T*>(src);
    std::fill_n(reinterpret_cast<T*>(dst), count, fill_value);
  }
};
// Factory helper: wraps a typed TensorFillImpl behind the primitive interface.
template<typename T>
std::unique_ptr<TensorFill> NewTensorFill() {
  return std::unique_ptr<TensorFill>(new TensorFillImpl<T>());
}
// Factory resolving a DataType to a CPU TensorFill primitive via a static
// lookup table; returns nullptr for unsupported dtypes.
class TensorFillFactoryImpl : public TensorFillFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl);
  TensorFillFactoryImpl() = default;
  ~TensorFillFactoryImpl() override = default;

  std::unique_ptr<TensorFill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<TensorFill>()>>
        new_fill_handle{OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};
    // BUGFIX: was `#undef MAKE_NEW_ADD_ENTRY` (copy-paste from the add
    // primitive), which undefined a nonexistent name and leaked
    // MAKE_NEW_FILL_ENTRY beyond this scope.
#undef MAKE_NEW_FILL_ENTRY
    const auto it = new_fill_handle.find(data_type);
    if (it != new_fill_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCPU
,
TensorFillFactory
,
TensorFillFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/cpu/primitive/type_seq.h
View file @
a715222c
...
...
@@ -35,6 +35,7 @@ limitations under the License.
#define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#define CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)
#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool)
...
...
@@ -63,12 +64,19 @@ limitations under the License.
#define CPU_PRIMITIVE_ALL_TYPE_SEQ \
CPU_PRIMITIVE_NATIVE_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ
CPU_PRIMITIVE_FLOAT16_TYPE_SEQ \
CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define CPU_PRIMITIVE_FLOATING_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ
#define CPU_PRIMITIVE_INT_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ
#define UTIL_OPS_DATA_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
...
...
oneflow/core/ep/cpu/primitive/unary_functor.h
View file @
a715222c
...
...
@@ -15,7 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include <cmath>
namespace
oneflow
{
namespace
ep
{
...
...
@@ -23,7 +22,7 @@ namespace primitive {
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kGelu
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Src
>
(
0.5
)
*
src
*
(
static_cast
<
Src
>
(
1.0
)
+
std
::
erf
(
inv_sqrt2
*
src
));
...
...
@@ -31,9 +30,42 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> {
Src
inv_sqrt2
=
std
::
sqrt
(
0.5
);
};
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kFastGelu
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
// ref to: https://mlfromscratch.com/activation-functions-explained/#gelu
const
Src
half
=
static_cast
<
Src
>
(
0.5
);
const
Src
one
=
static_cast
<
Src
>
(
1
);
const
Src
tanh_in
=
alpha
*
(
src
+
beta
*
src
*
src
*
src
);
return
half
*
src
*
(
one
+
std
::
tanh
(
tanh_in
));
}
private:
// constant ref to:
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py
static
constexpr
Src
alpha
=
static_cast
<
Src
>
(
0.7978845608028654
);
static
constexpr
Src
beta
=
static_cast
<
Src
>
(
0.044714998453855515
);
};
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kQuickGelu
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
const
Src
sigmoid
=
static_cast
<
Dst
>
(
static_cast
<
Src
>
(
1.0
)
/
(
static_cast
<
Src
>
(
1.0
)
+
exp
(
-
src
*
alpha
)));
return
src
*
sigmoid
;
}
private:
static
constexpr
Src
alpha
=
static_cast
<
Src
>
(
1.702
);
};
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kTanh
,
Dst
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
std
::
tanh
(
src
);
}
};
...
...
@@ -66,6 +98,109 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, double> {
OF_DEVICE_FUNC
bool
operator
()(
double
src
)
const
{
return
std
::
isnan
(
src
);
}
};
template
<
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kIsFinite
,
bool
,
Src
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
Src
src
)
const
{
return
std
::
isfinite
(
src
);
}
};
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kTrunc
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
std
::
trunc
(
src
));
}
};
template
<
typename
Dst
,
typename
Src
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kRsqrt
,
Dst
,
Src
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src
)
const
{
return
static_cast
<
Dst
>
(
static_cast
<
Src
>
(
1.0
)
/
static_cast
<
Src
>
(
std
::
sqrt
(
src
)));
}
};
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kAbs
,
bfloat16
,
bfloat16
>
{
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bfloat16
operator
()(
bfloat16
src
)
const
{
return
std
::
abs
(
src
);
}
};
#define SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(op) \
template<> \
struct UnaryFunctor<DeviceType::kCPU, op, bfloat16, bfloat16> { \
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
UnaryFunctor<DeviceType::kCPU, op, float, float> float_functor; \
OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { \
return bfloat16(float_functor(static_cast<float>(src))); \
} \
};
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kElu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kCelu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kGelu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kHardSwish
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kHardSigmoid
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kHardShrink
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kHardTanh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLeakyRelu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kMish
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSelu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSilu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSoftShrink
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSoftSign
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSoftPlus
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kTanh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kThreshold
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAcos
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAcosh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAsin
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAsinh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAtan
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kAtanh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kCeil
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kCos
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kCosh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kErf
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kErfc
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kExp
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kExpm1
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kFloor
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLgamma
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLog
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLog2
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLog1p
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kLogSigmoid
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kRint
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kRound
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kRsqrt
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSigmoid
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSin
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSinh
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSqrt
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kSquare
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kTan
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kReciprocalNoNan
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kNotEqualZero
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kFastGelu
);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR
(
UnaryOp
::
kQuickGelu
);
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kIsInf
,
bool
,
bfloat16
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
bfloat16
src
)
const
{
return
std
::
isinf
(
src
);
}
};
template
<
>
struct
UnaryFunctor
<
DeviceType
::
kCPU
,
UnaryOp
::
kIsNan
,
bool
,
bfloat16
>
{
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
bool
operator
()(
bfloat16
src
)
const
{
return
std
::
isnan
(
src
);
}
};
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/cuda/cuda_device.cpp
View file @
a715222c
...
...
@@ -55,6 +55,13 @@ CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
const_ones_buffer_bf16_
(
nullptr
)
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
OF_CUDA_CHECK
(
cudaGetDeviceProperties
(
&
properties_
,
device_index_
));
{
const
char
*
env_name
=
"ONEFLOW_EP_CUDA_DEVICE_FLAGS"
;
if
(
std
::
getenv
(
env_name
)
!=
nullptr
)
{
const
unsigned
int
flags
=
ParseIntegerFromEnv
(
env_name
,
0
);
OF_CUDA_CHECK
(
cudaSetDeviceFlags
(
flags
));
}
}
event_flags_
=
cudaEventDisableTiming
;
if
(
ParseBooleanFromEnv
(
"ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC"
,
false
))
{
event_flags_
|=
cudaEventBlockingSync
;
...
...
@@ -119,6 +126,10 @@ Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size
CHECK
(
!
options
.
HasPinnedDevice
());
cudaError_t
err
=
cudaMalloc
(
ptr
,
size
);
if
(
err
!=
cudaSuccess
)
{
if
(
err
==
cudaErrorMemoryAllocation
)
{
// NOTE:return out of memory error, so vm will try to shrink memory and rerun
return
Error
::
OutOfMemoryError
()
<<
cudaGetErrorString
(
err
);
}
return
Error
::
RuntimeError
()
<<
cudaGetErrorString
(
err
);
}
else
{
return
Maybe
<
void
>::
Ok
();
...
...
@@ -177,3 +188,176 @@ const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
}
// namespace oneflow
#endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif
namespace
oneflow
{
namespace
ep
{
namespace
{
constexpr
size_t
kDefaultConstBufElementCount
=
1024
*
1024
;
// Allocates a device buffer of `n` elements of T and fills it with `value`,
// staging through a host vector and copying with hipMemcpy.
template<typename T>
void CreateConstBuffer(void** buf, T value, size_t n) {
  const size_t bytes = n * sizeof(T);
  OF_CUDA_CHECK(hipMalloc(buf, bytes));
  std::vector<T> staging(n, value);
  OF_CUDA_CHECK(hipMemcpy(*buf, staging.data(), bytes, hipMemcpyDefault));
}
}
// namespace
// ROCm build of the device constructor: queries device properties, applies
// optional device/event flags from the environment, and pre-fills the
// constant zero/one buffers (bf16 ones are not created under ROCm; the
// member stays nullptr).
CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
    : device_index_(device_index),
      event_flags_{},
      properties_{},
      device_manager_(device_manager),
      const_buf_elem_cnt_(0),
      const_zeros_buffer_(nullptr),
      const_ones_buffer_fp32_(nullptr),
      const_ones_buffer_fp16_(nullptr),
      const_ones_buffer_bf16_(nullptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipGetDeviceProperties(&properties_, device_index_));
  {
    // Optional override of device scheduling flags.
    const char* env_name = "ONEFLOW_EP_CUDA_DEVICE_FLAGS";
    if (std::getenv(env_name) != nullptr) {
      const unsigned int flags = ParseIntegerFromEnv(env_name, 0);
      OF_CUDA_CHECK(hipSetDeviceFlags(flags));
    }
  }
  event_flags_ = hipEventDisableTiming;
  if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
    event_flags_ |= hipEventBlockingSync;
  }
  const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT",
                                            kDefaultConstBufElementCount);
  if (const_buf_elem_cnt_ > 0) {
    CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);
    CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),
                             const_buf_elem_cnt_);
    CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0),
                            const_buf_elem_cnt_);
    // #if CUDA_VERSION >= 11000
    // CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),
    // const_buf_elem_cnt_);
    // #endif
  }
}
// Releases pooled events, then the constant buffers. hipFree on a nullptr
// (e.g. the never-allocated bf16 buffer under ROCm) is documented as a
// successful no-op.
CudaDevice::~CudaDevice() {
  CudaCurrentDeviceGuard guard(device_index_);
  for (auto* event : events_) { delete event; }
  OF_CUDA_CHECK(hipFree(const_zeros_buffer_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_fp32_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_fp16_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_bf16_));
}
// Makes this device current for the calling thread.
void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(hipSetDevice(device_index_)); }
// Creates a new stream bound to this device (device made current first).
Stream* CudaDevice::CreateStream() {
  CudaCurrentDeviceGuard guard(device_index_);
  return new CudaStream(this);
}
// Destroys a stream previously returned by CreateStream.
void CudaDevice::DestroyStream(Stream* stream) {
  CudaCurrentDeviceGuard guard(device_index_);
  delete stream;
}
// Fills `events` with `count` events, serving as many as possible from the
// internal pool (taken from its tail, under the lock) and allocating the rest.
void CudaDevice::CreateEvents(Event** events, size_t count) {
  size_t from_pool = 0;
  {
    std::lock_guard<std::mutex> lock(events_mutex_);
    from_pool = std::min(count, events_.size());
    const size_t remaining = events_.size() - from_pool;
    std::copy(events_.begin() + remaining, events_.end(), events);
    events_.resize(remaining);
  }
  if (from_pool != count) {
    CudaCurrentDeviceGuard guard(device_index_);
    for (size_t i = from_pool; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }
  }
}
// Returns events to the pool for later reuse instead of destroying them.
void CudaDevice::DestroyEvents(Event** events, size_t count) {
  std::lock_guard<std::mutex> lock(events_mutex_);
  events_.insert(events_.end(), events, events + count);
}
// Device allocation. OOM is reported as a distinct error so the vm can try
// to shrink memory and rerun; any other failure is a runtime error.
Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {
  CudaCurrentDeviceGuard guard(device_index_);
  CHECK(!options.HasPinnedDevice());
  const hipError_t err = hipMalloc(ptr, size);
  if (err == hipSuccess) { return Maybe<void>::Ok(); }
  if (err == hipErrorMemoryAllocation) {
    // NOTE:return out of memory error, so vm will try to shrink memory and rerun
    return Error::OutOfMemoryError() << hipGetErrorString(err);
  }
  return Error::RuntimeError() << hipGetErrorString(err);
}
// Frees memory obtained from Alloc.
void CudaDevice::Free(const AllocationOptions& attr, void* ptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipFree(ptr));
}
// Pinned host allocation via the NUMA-aware helper; failures surface as
// runtime errors.
Maybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {
  CudaCurrentDeviceGuard guard(device_index_);
  const hipError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);
  if (err == hipSuccess) { return Maybe<void>::Ok(); }
  return Error::RuntimeError() << hipGetErrorString(err);
}
// Frees pinned host memory obtained from AllocPinned.
void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipHostFree(ptr));
}
// Cached device properties queried once in the constructor.
const hipDeviceProp_t& CudaDevice::properties() const { return properties_; }
// Returns the shared zero buffer when the request fits within its byte size
// (sized as const_buf_elem_cnt_ floats); otherwise nullptr.
const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {
  const size_t requested_bytes = GetSizeOfDataType(data_type) * n;
  if (requested_bytes <= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {
    return const_zeros_buffer_;
  }
  return nullptr;
}
// Returns the shared ones buffer for the requested dtype when n fits,
// otherwise nullptr. NOTE(review): under ROCm the bf16 buffer is never
// allocated (creation is commented out in the ctor), so kBFloat16 yields
// nullptr here — confirm callers handle that.
const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
  if (n > const_buf_elem_cnt_) { return nullptr; }
  if (data_type == DataType::kFloat) { return const_ones_buffer_fp32_; }
  if (data_type == DataType::kFloat16) { return const_ones_buffer_fp16_; }
  if (data_type == DataType::kBFloat16) { return const_ones_buffer_bf16_; }
  return nullptr;
}
}
// namespace ep
}
// namespace oneflow
#endif // WITH_ROCM
oneflow/core/ep/cuda/cuda_device.h
View file @
a715222c
...
...
@@ -75,4 +75,60 @@ class CudaDevice : public Device {
#endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
// ROCm declaration of the CUDA-named device abstraction: stream/event
// lifecycle, device and pinned-host allocation, and shared constant buffers.
class CudaDevice : public Device {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDevice);
  explicit CudaDevice(int device_index, DeviceManager* device_manager);
  ~CudaDevice() override;

  void SetAsActiveDevice() override;

  DeviceType device_type() const override { return DeviceType::kCUDA; }
  size_t device_index() const override { return device_index_; }
  DeviceManager* device_manager() const override { return device_manager_; }

  // Stream and event lifecycle (events are pooled; see the .cpp).
  Stream* CreateStream() override;
  void DestroyStream(Stream* stream) override;
  void CreateEvents(Event** events, size_t count) override;
  void DestroyEvents(Event** events, size_t count) override;

  // Device memory; Alloc reports OOM distinctly so callers can retry.
  Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;
  void Free(const AllocationOptions& options, void* ptr) override;
  // Pinned host memory.
  Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;
  void FreePinned(const AllocationOptions& options, void* ptr) override;

  const hipDeviceProp_t& properties() const;
  // Shared constant buffers; return nullptr when the request does not fit.
  const void* GetConstZeros(DataType data_type, size_t n) const;
  const void* GetConstOnes(DataType data_type, size_t n) const;

 private:
  int device_index_;
  std::mutex events_mutex_;
  std::vector<Event*> events_;  // pool of reusable events, guarded by events_mutex_
  unsigned int event_flags_;
  hipDeviceProp_t properties_;
  DeviceManager* device_manager_;
  int64_t const_buf_elem_cnt_;
  void* const_zeros_buffer_;
  void* const_ones_buffer_fp32_;
  void* const_ones_buffer_fp16_;
  void* const_ones_buffer_bf16_;  // not allocated under ROCm (see .cpp ctor)
};
}
// namespace ep
}
// namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_
Prev
1
…
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment