Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
8f7de847
Commit
8f7de847
authored
Apr 25, 2023
by
yuguo960516yuguo
Browse files
dtk
parent
f262efc9
Pipeline
#248
failed with stages
in 0 seconds
Changes
121
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2730 additions
and
2730 deletions
+2730
-2730
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
+150
-150
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
...re/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
+109
-109
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
...core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
+396
-396
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_activation_grad.hip.cpp
...tive/broadcast_elementwise_binary_activation_grad.hip.cpp
+39
-39
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_comparision.hip.cpp
...rimitive/broadcast_elementwise_binary_comparision.hip.cpp
+38
-38
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_logical.hip.cpp
...cm/primitive/broadcast_elementwise_binary_logical.hip.cpp
+38
-38
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
.../rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
+35
-35
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
+237
-237
oneflow/core/ep/rocm/primitive/cast.hip.cpp
oneflow/core/ep/rocm/primitive/cast.hip.cpp
+148
-148
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
+254
-254
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
+95
-95
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
+116
-116
oneflow/core/ep/rocm/primitive/fill.hip.cpp
oneflow/core/ep/rocm/primitive/fill.hip.cpp
+151
-151
oneflow/core/ep/rocm/primitive/memcpy.cpp
oneflow/core/ep/rocm/primitive/memcpy.cpp
+62
-62
oneflow/core/ep/rocm/primitive/memset.cpp
oneflow/core/ep/rocm/primitive/memset.cpp
+59
-59
oneflow/core/ep/rocm/primitive/permute.hip.cpp
oneflow/core/ep/rocm/primitive/permute.hip.cpp
+333
-333
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
+107
-107
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
+116
-116
oneflow/core/ep/rocm/primitive/type_seq.h
oneflow/core/ep/rocm/primitive/type_seq.h
+77
-77
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
+170
-170
No files found.
Too many changes to show.
To preserve performance only
121 of 121+
files are displayed.
Plain diff
Email patch
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
// Generic power functor: dst = pow(src0, src1).
// The unqualified pow resolves to the device overload under nvcc/hipcc.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
  // kPow takes no attributes; the constructor exists only to satisfy the
  // uniform functor interface.
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
};
// kPow specialization for bool: promote to double, compute, narrow back.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
    const double base = static_cast<double>(src0);
    const double exponent = static_cast<double>(src1);
    return static_cast<bool>(pow(base, exponent));
  }
};
// kPow specialization for half: no native half pow, so widen to float,
// compute, then narrow the result back to half.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC half operator()(half src0, half src1) const {
    const float base = static_cast<float>(src0);
    const float exponent = static_cast<float>(src1);
    return static_cast<half>(pow(base, exponent));
  }
};
// GELU backward: dx = 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2/pi) * exp(-x^2 / 2)) * dy.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
    // coef = sqrt(2 / pi), with pi spelled as acos(-1).
    // The original code had byte-identical branches for __CUDA_ARCH__ and
    // __HIP_DEVICE_COMPILE__; they are merged here. Device passes resolve the
    // unqualified sqrt/acos to device overloads, host passes use std::.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
    coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif
  }

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return static_cast<Src>(0.5)
           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
              + x * coef * exp(static_cast<Src>(-0.5) * x * x))
           * dy;
  }

  Src coef;  // sqrt(2/pi), computed once at construction
};
// tanh backward: dx = dy * (1 - tanh(x)^2).
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src t = tanh(x);
    return static_cast<Dst>(dy * (static_cast<Src>(1.0) - t * t));
  }
};
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// }
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000
// Specialize a backward functor for half by delegating to the float functor:
// half inputs are widened with __half2float, computed in float, and the result
// narrowed back with __float2half.
// NOTE: the extracted source had the `\` line continuations stranded on their
// own lines, which does not compile; this reconstructs a valid multi-line macro.
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op)                                         \
  template<>                                                                                  \
  struct BinaryFunctor<DeviceType::kCUDA, op, half, half> {                                   \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
                                                                                              \
    BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \
    OF_DEVICE_FUNC half operator()(half src0, half src1) const {                              \
      return __float2half(float_functor(__half2float(src0), __half2float(src1)));             \
    }                                                                                         \
  };
// Instantiate the pseudo-half specialization for every activation backward op.
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
// Generic power functor: dst = pow(src0, src1).
// The unqualified pow resolves to the device overload under nvcc/hipcc.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
  // kPow takes no attributes; the constructor exists only to satisfy the
  // uniform functor interface.
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
};
// kPow specialization for bool: promote to double, compute, narrow back.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
    const double base = static_cast<double>(src0);
    const double exponent = static_cast<double>(src1);
    return static_cast<bool>(pow(base, exponent));
  }
};
// kPow specialization for half: no native half pow, so widen to float,
// compute, then narrow the result back to half.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC half operator()(half src0, half src1) const {
    const float base = static_cast<float>(src0);
    const float exponent = static_cast<float>(src1);
    return static_cast<half>(pow(base, exponent));
  }
};
// GELU backward: dx = 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2/pi) * exp(-x^2 / 2)) * dy.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
    // coef = sqrt(2 / pi), with pi spelled as acos(-1).
    // The original code had byte-identical branches for __CUDA_ARCH__ and
    // __HIP_DEVICE_COMPILE__; they are merged here. Device passes resolve the
    // unqualified sqrt/acos to device overloads, host passes use std::.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
    coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif
  }

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return static_cast<Src>(0.5)
           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
              + x * coef * exp(static_cast<Src>(-0.5) * x * x))
           * dy;
  }

  Src coef;  // sqrt(2/pi), computed once at construction
};
// tanh backward: dx = dy * (1 - tanh(x)^2).
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src t = tanh(x);
    return static_cast<Dst>(dy * (static_cast<Src>(1.0) - t * t));
  }
};
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// }
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000
// Specialize a backward functor for half by delegating to the float functor:
// half inputs are widened with __half2float, computed in float, and the result
// narrowed back with __float2half.
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \
template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, half, half> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src0, half src1) const { \
return __float2half(float_functor(__half2float(src0), __half2float(src1))); \
} \
};
// Instantiate the pseudo-half specialization for every activation backward op.
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
Scalar
attr1
);
namespace
{
class
BroadcastElementwiseBinaryFactoryImpl
:
public
BroadcastElementwiseBinaryFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryFactoryImpl
);
BroadcastElementwiseBinaryFactoryImpl
()
=
default
;
~
BroadcastElementwiseBinaryFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
Scalar
(),
Scalar
());
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
attr0
,
Scalar
());
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
binary_op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
,
Scalar
attr1
)
override
{
if
(
max_num_dims
>
kMaxNumDims
)
{
return
nullptr
;
}
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
data_type_pair
),
\
OF_PP_PAIR_SECOND
(
data_type_pair
)),
\
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
\
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op
,
src_data_type_pair
,
dst_data_type_pair
)
\
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
src_data_type_pair
),
\
OF_PP_PAIR_SECOND
(
dst_data_type_pair
)),
\
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
src_data_type_pair
),
\
OF_PP_PAIR_FIRST
(
dst_data_type_pair
)
>
},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
{
std
::
make_tuple
(
binary_op
,
OF_PP_PAIR_SECOND
(
data_type_pair
),
\
OF_PP_PAIR_SECOND
(
data_type_pair
)),
\
NewBroadcastElementwiseBinary
<
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
\
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
},
static
const
std
::
map
<
std
::
tuple
<
BinaryOp
,
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
Scalar
,
Scalar
)
>>
new_broadcast_elementwise_binary_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
,
BINARY_COMPARISION_OP_SEQ
BINARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
const
auto
it
=
new_broadcast_elementwise_binary_handle
.
find
(
std
::
make_tuple
(
binary_op
,
src_type
,
dst_type
));
if
(
it
!=
new_broadcast_elementwise_binary_handle
.
end
())
{
return
it
->
second
(
attr0
,
attr1
);
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastElementwiseBinaryFactory
,
BroadcastElementwiseBinaryFactoryImpl
);
}
// namespace
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
(
Scalar
attr0
,
Scalar
attr1
);
namespace
{
class
BroadcastElementwiseBinaryFactoryImpl
:
public
BroadcastElementwiseBinaryFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryFactoryImpl
);
BroadcastElementwiseBinaryFactoryImpl
()
=
default
;
~
BroadcastElementwiseBinaryFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
Scalar
(),
Scalar
());
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
)
override
{
return
New
(
op
,
src_type
,
dst_type
,
max_num_dims
,
attr0
,
Scalar
());
}
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
New
(
BinaryOp
binary_op
,
DataType
src_type
,
DataType
dst_type
,
size_t
max_num_dims
,
Scalar
attr0
,
Scalar
attr1
)
override
{
if
(
max_num_dims
>
kMaxNumDims
)
{
return
nullptr
;
}
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op, src_data_type_pair, dst_data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND(dst_data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \
OF_PP_PAIR_FIRST(dst_data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
static
const
std
::
map
<
std
::
tuple
<
BinaryOp
,
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
(
Scalar
,
Scalar
)
>>
new_broadcast_elementwise_binary_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
,
BINARY_COMPARISION_OP_SEQ
BINARY_LOGICAL_OP_SEQ
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
const
auto
it
=
new_broadcast_elementwise_binary_handle
.
find
(
std
::
make_tuple
(
binary_op
,
src_type
,
dst_type
));
if
(
it
!=
new_broadcast_elementwise_binary_handle
.
end
())
{
return
it
->
second
(
attr0
,
attr1
);
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastElementwiseBinaryFactory
,
BroadcastElementwiseBinaryFactoryImpl
);
}
// namespace
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
{
// Storage type with the size AND alignment of N consecutive T values, so a
// pack of N elements can be moved with one aligned (vectorized) memory access.
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};

// Convenience alias for GetPackType<T, N>::type.
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
// Union view over a pack: `storage` for single-transaction load/store,
// `elem` for per-element access inside the kernel.
template<typename T, int N>
union Pack {
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
  // Intentionally empty: members are written before being read, so no
  // initialization is needed (and unions cannot default-construct both views).
  OF_DEVICE_FUNC Pack() {}
  PackType<T, N> storage;
  T elem[N];
};
// Plain parameter struct passed by value to the kernel launch.
template<size_t max_dims, typename IndexType>
struct BroadcastElementwiseBinaryParams {
  NdIndexOffsetHelper<IndexType, max_dims> src0_index_helper;
  NdIndexOffsetHelper<IndexType, max_dims> src1_index_helper;
  NdIndexOffsetHelper<IndexType, max_dims> dst_index_helper;
  size_t num_dims;
  // Per-dim 0/1 masks: 0 where the operand has extent 1 (broadcast dim).
  IndexType src0_index_mask[max_dims];
  IndexType src1_index_mask[max_dims];
  IndexType count{};
  const void* src0{};
  const void* src1{};
  void* dst{};
  Scalar attr0;
  Scalar attr1;
};
// Elementwise binary kernel with broadcasting. Each thread handles one dst
// pack per loop iteration: it maps the flat offset to an nd-index, masks that
// index per operand to realize broadcasting, loads one pack from each source,
// applies the functor per element, and stores the result pack.
template<BinaryOp binary_op, typename Src, typename Dst, size_t max_dims, size_t src0_pack_size,
         size_t src1_pack_size, typename IndexType>
__global__ void BroadcastElementwiseBinaryGpu(
    BroadcastElementwiseBinaryParams<max_dims, IndexType> params) {
  // dst pack is the larger of the two source packs; the smaller source must be
  // a scalar pack (size 1) that is splatted across the dst pack.
  constexpr size_t dst_pack_size =
      src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size;
  static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, "");
  static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, "");

  const PackType<Src, src0_pack_size>* src0 =
      reinterpret_cast<const PackType<Src, src0_pack_size>*>(params.src0);
  const PackType<Src, src1_pack_size>* src1 =
      reinterpret_cast<const PackType<Src, src1_pack_size>*>(params.src1);
  PackType<Dst, dst_pack_size>* dst =
      reinterpret_cast<PackType<Dst, dst_pack_size>*>(params.dst);

  IndexType src0_index[max_dims];
  IndexType src1_index[max_dims];
  IndexType dst_index[max_dims];
  size_t num_dims = params.num_dims;
  CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {
    params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
#pragma unroll
    for (int i = 0; i < max_dims; ++i) {
      if (i < num_dims) {
        // mask is 0 on broadcast dims, pinning that coordinate to 0.
        src0_index[i] = params.src0_index_mask[i] * dst_index[i];
        src1_index[i] = params.src1_index_mask[i] * dst_index[i];
      } else {
        src0_index[i] = 0;
        src1_index[i] = 0;
      }
    }
    const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims);
    const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims);
    Pack<Src, src0_pack_size> src0_pack;
    src0_pack.storage = src0[src0_offset];
    Pack<Src, src1_pack_size> src1_pack;
    src1_pack.storage = src1[src1_offset];
    Pack<Dst, dst_pack_size> dst_pack;
    BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor(params.attr0, params.attr1);
#pragma unroll
    for (int j = 0; j < dst_pack_size; ++j) {
      // A size-1 source pack contributes its single element to every lane.
      const Src src0_val =
          (src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0];
      const Src src1_val =
          (src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0];
      dst_pack.elem[j] = functor(src0_val, src1_val);
    }
    dst[offset] = dst_pack.storage;
  }
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
int
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
size_t
count
,
Scalar
attr0
,
Scalar
attr1
)
{
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
;
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
params
.
src0_index_mask
[
i
]
=
(
src0_dims
[
i
]
==
1
)
?
0
:
1
;
params
.
src1_index_mask
[
i
]
=
(
src1_dims
[
i
]
==
1
)
?
0
:
1
;
}
params
.
src0_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src0_dims
,
num_dims
);
params
.
src1_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src1_dims
,
num_dims
);
params
.
dst_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
dst_dims
,
num_dims
);
params
.
num_dims
=
num_dims
;
params
.
src0
=
src0
;
params
.
src1
=
src1
;
params
.
dst
=
dst
;
params
.
count
=
static_cast
<
IndexType
>
(
count
);
params
.
attr0
=
attr0
;
params
.
attr1
=
attr1
;
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
BroadcastElementwiseBinaryGpu
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
->
cuda_stream
()
>>>
(
params
);
}
// Picks the narrowest index type able to address every destination element:
// 32-bit indexing when the element count fits, 64-bit otherwise.
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
         size_t src1_pack_size>
void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0,
                       const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
                       void* dst, Scalar attr0, Scalar attr1) {
  const size_t element_count = GetElementCount(num_dims, dst_dims);
  const bool fits_in_int32 = element_count < GetMaxVal<int32_t>();
  if (fits_in_int32) {
    LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>(
        stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, element_count, attr0,
        attr1);
  } else {
    LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>(
        stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, element_count, attr0,
        attr1);
  }
}
// Maps the runtime (src0_pack_size, src1_pack_size) pair onto a compile-time
// DispatchIndexType instantiation. Only the combinations 1/1, 1/4, 4/1 and
// 4/4 are generated, matching GetPackSize which only ever returns 1 or 4;
// anything else aborts.
template<BinaryOp op, typename T, typename R, size_t max_dims>
void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size,
                      size_t num_dims, const int64_t* src0_dims, const void* src0,
                      const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
                      void* dst, Scalar attr0, Scalar attr1) {
  const auto launch = [&](auto dispatch) {
    dispatch(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1);
  };
  if (src0_pack_size == 1 && src1_pack_size == 1) {
    launch(DispatchIndexType<op, T, R, max_dims, 1, 1>);
  } else if (src0_pack_size == 1 && src1_pack_size == 4) {
    launch(DispatchIndexType<op, T, R, max_dims, 1, 4>);
  } else if (src0_pack_size == 4 && src1_pack_size == 1) {
    launch(DispatchIndexType<op, T, R, max_dims, 4, 1>);
  } else if (src0_pack_size == 4 && src1_pack_size == 4) {
    launch(DispatchIndexType<op, T, R, max_dims, 4, 4>);
  } else {
    UNIMPLEMENTED();
  }
}
// Selects the DispatchPackSize instantiation whose max_dims template argument
// is the smallest supported value (2, 3, 4 or 8) covering num_dims.
// num_dims == 1 is rejected: that case is handled by the element-wise /
// scalar fast paths in DispatchLaunch before reaching here.
template<BinaryOp op, typename T, typename R>
void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
                     const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
                     const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
                     Scalar attr1) {
  void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/,
               size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/,
               const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/,
               void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
  CHECK_NE(num_dims, 1);
  if (num_dims == 2) {
    func = DispatchPackSize<op, T, R, 2>;
  } else if (num_dims == 3) {
    func = DispatchPackSize<op, T, R, 3>;
  } else if (num_dims == 4) {
    func = DispatchPackSize<op, T, R, 4>;
  } else if (num_dims <= 8) {
    func = DispatchPackSize<op, T, R, 8>;
  } else {
    UNIMPLEMENTED();
  }
  func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1,
       dst_dims, dst, attr0, attr1);
}
// Returns the widest vectorized load/store width (in elements) usable for the
// innermost dimension, or 1 when no candidate qualifies. A source whose
// innermost extent is 1 is broadcast there and does not constrain the choice;
// the destination pointer must be aligned to the packed width of R.
template<size_t max_pack_size, typename T, typename R>
size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0,
                   const int64_t* src1_dims, const void* src1, void* dst) {
  // Pack widths must be powers of two for the halving loop below.
  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, "");
  // At least one source must have a non-broadcast innermost dimension.
  CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1);
  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
  // NOTE: condition is pack_size > 2, so with kMaxPackSize == 4 only a pack
  // of 4 is ever tried before falling back to 1. This matches
  // DispatchPackSize, which only instantiates pack sizes 1 and 4.
  for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {
    bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1)
                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src0_dims, src0);
    bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1)
                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src1_dims, src1);
    if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) {
      return pack_size;
    }
  }
  return 1;
}
// Upper bound on the vectorized pack width tried by GetPackSize.
constexpr size_t kMaxPackSize = 4;
// Chooses a pack width, rescales the innermost extent of every operand by it
// (a broadcast innermost dim of extent 1 keeps pack size 1) and forwards to
// the num_dims dispatcher.
// NOTE: the simplified_* dimension arrays are modified in place.
template<BinaryOp op, typename T, typename R>
void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims,
                          int64_t* simplified_src0_dims, const void* src0,
                          int64_t* simplified_src1_dims, const void* src1,
                          int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) {
  CHECK_LE(simplified_num_dims, kMaxNumDims);
  size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims,
                                                     src0, simplified_src1_dims, src1, dst);
  size_t src0_pack_size = 1;
  size_t src1_pack_size = 1;
  if (simplified_src0_dims[simplified_num_dims - 1] != 1) {
    // Non-broadcast innermost dim: iterate it in packs of pack_size.
    simplified_src0_dims[simplified_num_dims - 1] /= pack_size;
    src0_pack_size = pack_size;
  }
  if (simplified_src1_dims[simplified_num_dims - 1] != 1) {
    simplified_src1_dims[simplified_num_dims - 1] /= pack_size;
    src1_pack_size = pack_size;
  }
  simplified_dst_dims[simplified_num_dims - 1] /= pack_size;
  DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims,
                            simplified_src0_dims, src0, simplified_src1_dims, src1,
                            simplified_dst_dims, dst, attr0, attr1);
}
// Adapts a BinaryFunctor into a unary functor whose LEFT operand is a fixed
// scalar: operator()(src) == functor(scalar, src).
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarFunctor {
  __host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
      : scalar(scalar), functor(attr0, attr1) {}
  __device__ Dst operator()(Src src) const { return functor(scalar, src); }
  const Src scalar;
  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
// Adapts a BinaryFunctor into a unary functor whose RIGHT operand is a fixed
// scalar: operator()(src) == functor(src, scalar).
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarFunctor {
  __host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
      : scalar(scalar), functor(attr0, attr1) {}
  __device__ Dst operator()(Src src) const { return functor(src, scalar); }
  const Src scalar;
  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
// Factory whose operator() runs on the device: it dereferences the scalar
// pointer there and builds a BinaryLhsScalarFunctor from the value plus the
// stored attributes. Used when the lhs operand is a single-element tensor.
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarPtrFunctorFactory {
  // `explicit` added for consistency with BinaryRhsScalarPtrFunctorFactory.
  __host__ __device__ explicit BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr,
                                                                Scalar attr0, Scalar attr1)
      : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
  __device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const {
    return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
  }
  const Src* scalar_ptr;
  Scalar attr0, attr1;
};
// Factory whose operator() runs on the device: it dereferences the scalar
// pointer there and builds a BinaryRhsScalarFunctor from the value plus the
// stored attributes. Used when the rhs operand is a single-element tensor.
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarPtrFunctorFactory {
  __host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr,
                                                                Scalar attr0, Scalar attr1)
      : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
  __device__ BinaryRhsScalarFunctor<binary_op, Src, Dst> operator()() const {
    return BinaryRhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
  }
  const Src* scalar_ptr;
  Scalar attr0, attr1;
};
// Top-level dispatcher for tensor-x-tensor broadcasting. Simplifies/merges
// the broadcast dims, then routes to one of three paths:
//   (a) shapes equal after simplification -> plain element-wise kernel;
//   (b) one side reduces to a single element -> scalar-pointer unary kernel
//       (the scalar is read on the device, no host round-trip);
//   (c) otherwise -> the generic packed broadcast kernel.
template<BinaryOp binary_op, typename Src, typename Dst>
void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims,
                    const Src* src0, size_t num_src1_dims, const int64_t* src1_dims,
                    const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) {
  auto* cuda_stream = stream->As<CudaStream>();
  size_t simplified_num_dims = 0;
  int64_t simplified_src0_dims[kMaxNumDims];
  int64_t simplified_src1_dims[kMaxNumDims];
  int64_t simplified_dst_dims[kMaxNumDims];
  SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
                                     &simplified_num_dims, simplified_src0_dims,
                                     simplified_src1_dims, simplified_dst_dims);
  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,
               simplified_dst_dims, dst);
  if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
                   simplified_src1_dims)) {
    // No broadcasting left after simplification: run the element-wise kernel.
    const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);
    OF_CUDA_CHECK((cuda::elementwise::Binary(
        BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst>(attr0, attr1), elem_cnt, dst, src0,
        src1, cuda_stream->cuda_stream())));
  } else {
    if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {
      // src0 is a single element: broadcast it over src1.
      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
          BinaryLhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src0, attr0, attr1),
          simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream())));
    } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {
      // src1 is a single element: broadcast it over src0.
      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
          BinaryRhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src1, attr0, attr1),
          simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream())));
    } else {
      LaunchWithSimplified<binary_op, Src, Dst>(stream, simplified_num_dims,
                                                simplified_src0_dims, src0, simplified_src1_dims,
                                                src1, simplified_dst_dims, dst, attr0, attr1);
    }
  }
}
// Extracts the scalar's value as T on the host.
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}
// Specialization for half: fetch the value as float, then narrow.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
// Primitive implementation bound to (binary_op, Src, Dst). Stores the two op
// attributes and provides the three Launch overloads: host-scalar lhs,
// host-scalar rhs, and general tensor-x-tensor broadcasting.
template<BinaryOp binary_op, typename Src, typename Dst>
class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);
  BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
  ~BroadcastElementwiseBinaryImpl() override = default;

  // Host scalar src0 OP tensor src1 -> dst, via a unary element-wise kernel
  // with the scalar baked into the functor.
  void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
              const void* src1, void* dst) override {
    auto* cuda_stream = stream->As<CudaStream>();
    const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);
    OF_CUDA_CHECK((cuda::elementwise::Unary(
        BinaryLhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src0), attr0, attr1), elem_cnt,
        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src1),
        cuda_stream->cuda_stream())));
  }

  // Tensor src0 OP host scalar src1 -> dst.
  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
              Scalar src1, void* dst) override {
    auto* cuda_stream = stream->As<CudaStream>();
    const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);
    OF_CUDA_CHECK((cuda::elementwise::Unary(
        BinaryRhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src1), attr0, attr1), elem_cnt,
        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src0),
        cuda_stream->cuda_stream())));
  }

  // General tensor OP tensor with broadcasting.
  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
              size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
              void* dst) override {
    DispatchLaunch<binary_op, Src, Dst>(
        stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,
        src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);
  }

 private:
  Scalar attr0, attr1;
};
}
// namespace
// Factory: creates a BroadcastElementwiseBinary primitive specialized for
// (binary_op, Src, Dst), forwarding the two op attributes.
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
                                                                          Scalar attr1) {
  // make_unique avoids the naked new; the Derived->Base unique_ptr
  // conversion is implicit. <memory> is already required by the return type.
  return std::make_unique<BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>>(attr0, attr1);
}
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
{
// Raw storage type whose size AND alignment equal N packed elements of T;
// used for vectorized global-memory loads/stores.
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
// Convenience alias for GetPackType<T, N>::type.
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template
<
typename
T
,
int
N
>
union
Pack
{
static_assert
(
sizeof
(
PackType
<
T
,
N
>
)
==
sizeof
(
T
)
*
N
,
""
);
OF_DEVICE_FUNC
Pack
()
{
// do nothing
}
PackType
<
T
,
N
>
storage
;
T
elem
[
N
];
};
template
<
size_t
max_dims
,
typename
IndexType
>
struct
BroadcastElementwiseBinaryParams
{
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src0_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
src1_index_helper
;
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
dst_index_helper
;
size_t
num_dims
;
IndexType
src0_index_mask
[
max_dims
];
IndexType
src1_index_mask
[
max_dims
];
IndexType
count
{};
const
void
*
src0
{};
const
void
*
src1
{};
void
*
dst
{};
Scalar
attr0
;
Scalar
attr1
;
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
typename
IndexType
>
__global__
void
BroadcastElementwiseBinaryGpu
(
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
)
{
constexpr
size_t
dst_pack_size
=
src0_pack_size
>
src1_pack_size
?
src0_pack_size
:
src1_pack_size
;
static_assert
(
src0_pack_size
==
dst_pack_size
||
src0_pack_size
==
1
,
""
);
static_assert
(
src1_pack_size
==
dst_pack_size
||
src1_pack_size
==
1
,
""
);
const
PackType
<
Src
,
src0_pack_size
>*
src0
=
reinterpret_cast
<
const
PackType
<
Src
,
src0_pack_size
>*>
(
params
.
src0
);
const
PackType
<
Src
,
src1_pack_size
>*
src1
=
reinterpret_cast
<
const
PackType
<
Src
,
src1_pack_size
>*>
(
params
.
src1
);
PackType
<
Dst
,
dst_pack_size
>*
dst
=
reinterpret_cast
<
PackType
<
Dst
,
dst_pack_size
>*>
(
params
.
dst
);
IndexType
src0_index
[
max_dims
];
IndexType
src1_index
[
max_dims
];
IndexType
dst_index
[
max_dims
];
size_t
num_dims
=
params
.
num_dims
;
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
offset
,
params
.
count
)
{
params
.
dst_index_helper
.
OffsetToNdIndex
(
offset
,
dst_index
,
num_dims
);
#pragma unroll
for
(
int
i
=
0
;
i
<
max_dims
;
++
i
)
{
if
(
i
<
num_dims
)
{
src0_index
[
i
]
=
params
.
src0_index_mask
[
i
]
*
dst_index
[
i
];
src1_index
[
i
]
=
params
.
src1_index_mask
[
i
]
*
dst_index
[
i
];
}
else
{
src0_index
[
i
]
=
0
;
src1_index
[
i
]
=
0
;
}
}
const
IndexType
src0_offset
=
params
.
src0_index_helper
.
NdIndexToOffset
(
src0_index
,
num_dims
);
const
IndexType
src1_offset
=
params
.
src1_index_helper
.
NdIndexToOffset
(
src1_index
,
num_dims
);
Pack
<
Src
,
src0_pack_size
>
src0_pack
;
src0_pack
.
storage
=
src0
[
src0_offset
];
Pack
<
Src
,
src1_pack_size
>
src1_pack
;
src1_pack
.
storage
=
src1
[
src1_offset
];
Pack
<
Dst
,
dst_pack_size
>
dst_pack
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
(
params
.
attr0
,
params
.
attr1
);
#pragma unroll
for
(
int
j
=
0
;
j
<
dst_pack_size
;
++
j
)
{
const
Src
src0_val
=
(
src0_pack_size
==
dst_pack_size
)
?
src0_pack
.
elem
[
j
]
:
src0_pack
.
elem
[
0
];
const
Src
src1_val
=
(
src1_pack_size
==
dst_pack_size
)
?
src1_pack
.
elem
[
j
]
:
src1_pack
.
elem
[
0
];
dst_pack
.
elem
[
j
]
=
functor
(
src0_val
,
src1_val
);
}
dst
[
offset
]
=
dst_pack
.
storage
;
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
int
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
size_t
count
,
Scalar
attr0
,
Scalar
attr1
)
{
BroadcastElementwiseBinaryParams
<
max_dims
,
IndexType
>
params
;
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
params
.
src0_index_mask
[
i
]
=
(
src0_dims
[
i
]
==
1
)
?
0
:
1
;
params
.
src1_index_mask
[
i
]
=
(
src1_dims
[
i
]
==
1
)
?
0
:
1
;
}
params
.
src0_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src0_dims
,
num_dims
);
params
.
src1_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
src1_dims
,
num_dims
);
params
.
dst_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
max_dims
>
(
dst_dims
,
num_dims
);
params
.
num_dims
=
num_dims
;
params
.
src0
=
src0
;
params
.
src1
=
src1
;
params
.
dst
=
dst
;
params
.
count
=
static_cast
<
IndexType
>
(
count
);
params
.
attr0
=
attr0
;
params
.
attr1
=
attr1
;
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
BroadcastElementwiseBinaryGpu
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
->
cuda_stream
()
>>>
(
params
);
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
,
size_t
src0_pack_size
,
size_t
src1_pack_size
>
void
DispatchIndexType
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
size_t
count
=
GetElementCount
(
num_dims
,
dst_dims
);
if
(
count
<
GetMaxVal
<
int32_t
>
())
{
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int32_t
>
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
}
else
{
LaunchKernel
<
op
,
T
,
R
,
max_dims
,
src0_pack_size
,
src1_pack_size
,
int64_t
>
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
count
,
attr0
,
attr1
);
}
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
,
size_t
max_dims
>
void
DispatchPackSize
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
const
void
*
/*src0*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
int64_t
*
/*dst_dims*/
,
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
nullptr
;
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
1
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
1
>
;
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
4
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
4
>
;
}
else
if
(
src0_pack_size
==
1
&&
src1_pack_size
==
4
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
1
,
4
>
;
}
else
if
(
src0_pack_size
==
4
&&
src1_pack_size
==
1
)
{
func
=
DispatchIndexType
<
op
,
T
,
R
,
max_dims
,
4
,
1
>
;
}
else
{
UNIMPLEMENTED
();
}
func
(
stream
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
attr0
,
attr1
);
}
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
void
DispatchNumDims
(
Stream
*
stream
,
size_t
src0_pack_size
,
size_t
src1_pack_size
,
size_t
num_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
const
int64_t
*
dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
size_t
/*src0_pack_size*/
,
size_t
/*src1_pack_size*/
,
size_t
/*num_dims*/
,
const
int64_t
*
/*src0_dims*/
,
const
void
*
/*src0*/
,
const
int64_t
*
/*src1_dims*/
,
const
void
*
/*src1*/
,
const
int64_t
*
/*dst_dims*/
,
void
*
/*dst*/
,
Scalar
/*attr0*/
,
Scalar
/*attr1*/
)
=
nullptr
;
CHECK_NE
(
num_dims
,
1
);
if
(
num_dims
==
2
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
2
>
;
}
else
if
(
num_dims
==
3
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
3
>
;
}
else
if
(
num_dims
==
4
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
4
>
;
}
else
if
(
num_dims
<=
8
)
{
func
=
DispatchPackSize
<
op
,
T
,
R
,
8
>
;
}
else
{
UNIMPLEMENTED
();
}
func
(
stream
,
src0_pack_size
,
src1_pack_size
,
num_dims
,
src0_dims
,
src0
,
src1_dims
,
src1
,
dst_dims
,
dst
,
attr0
,
attr1
);
}
template
<
size_t
max_pack_size
,
typename
T
,
typename
R
>
size_t
GetPackSize
(
size_t
num_src_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
{
static_assert
(
max_pack_size
>
0
&&
(
max_pack_size
&
(
max_pack_size
-
1
))
==
0
,
""
);
CHECK
(
src0_dims
[
num_src_dims
-
1
]
!=
1
||
src1_dims
[
num_src_dims
-
1
]
!=
1
);
auto
dst_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
dst
);
for
(
size_t
pack_size
=
max_pack_size
;
pack_size
>
2
;
pack_size
/=
2
)
{
bool
is_src0_supported
=
(
src0_dims
[
num_src_dims
-
1
]
==
1
)
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src0_dims
,
src0
);
bool
is_src1_supported
=
(
src1_dims
[
num_src_dims
-
1
]
==
1
)
||
IsPackSizeSupported
<
T
>
(
pack_size
,
num_src_dims
,
src1_dims
,
src1
);
if
(
is_src0_supported
&&
is_src1_supported
&&
(
dst_ptr
%
(
pack_size
*
sizeof
(
R
)))
==
0
)
{
return
pack_size
;
}
}
return
1
;
}
// Upper bound on the vectorized pack width tried by GetPackSize.
constexpr size_t kMaxPackSize = 4;
template
<
BinaryOp
op
,
typename
T
,
typename
R
>
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
simplified_num_dims
,
int64_t
*
simplified_src0_dims
,
const
void
*
src0
,
int64_t
*
simplified_src1_dims
,
const
void
*
src1
,
int64_t
*
simplified_dst_dims
,
void
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
CHECK_LE
(
simplified_num_dims
,
kMaxNumDims
);
size_t
pack_size
=
GetPackSize
<
kMaxPackSize
,
T
,
R
>
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
dst
);
size_t
src0_pack_size
=
1
;
size_t
src1_pack_size
=
1
;
if
(
simplified_src0_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
simplified_src0_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
src0_pack_size
=
pack_size
;
}
if
(
simplified_src1_dims
[
simplified_num_dims
-
1
]
!=
1
)
{
simplified_src1_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
src1_pack_size
=
pack_size
;
}
simplified_dst_dims
[
simplified_num_dims
-
1
]
/=
pack_size
;
DispatchNumDims
<
op
,
T
,
R
>
(
stream
,
src0_pack_size
,
src1_pack_size
,
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
}
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryLhsScalarFunctor
{
__host__
__device__
BinaryLhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
scalar
,
src
);
}
const
Src
scalar
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryRhsScalarFunctor
{
__host__
__device__
BinaryRhsScalarFunctor
(
Src
scalar
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar
(
scalar
),
functor
(
attr0
,
attr1
)
{}
__device__
Dst
operator
()(
Src
src
)
const
{
return
functor
(
src
,
scalar
);
}
const
Src
scalar
;
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
functor
;
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryLhsScalarPtrFunctorFactory
{
__host__
__device__
BinaryLhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
__device__
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
return
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
}
const
Src
*
scalar_ptr
;
Scalar
attr0
,
attr1
;
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
struct
BinaryRhsScalarPtrFunctorFactory
{
__host__
__device__
explicit
BinaryRhsScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
,
Scalar
attr0
,
Scalar
attr1
)
:
scalar_ptr
(
scalar_ptr
),
attr0
(
attr0
),
attr1
(
attr1
)
{}
__device__
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
operator
()()
const
{
return
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
*
scalar_ptr
,
attr0
,
attr1
);
}
const
Src
*
scalar_ptr
;
Scalar
attr0
,
attr1
;
};
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
void
DispatchLaunch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
Src
*
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
Src
*
src1
,
Dst
*
dst
,
Scalar
attr0
,
Scalar
attr1
)
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
size_t
simplified_num_dims
=
0
;
int64_t
simplified_src0_dims
[
kMaxNumDims
];
int64_t
simplified_src1_dims
[
kMaxNumDims
];
int64_t
simplified_dst_dims
[
kMaxNumDims
];
SimplifyBroadcastDims
<
kMaxNumDims
>
(
num_src0_dims
,
src0_dims
,
num_src1_dims
,
src1_dims
,
&
simplified_num_dims
,
simplified_src0_dims
,
simplified_src1_dims
,
simplified_dst_dims
);
CheckInplace
(
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
);
if
(
IsDimsEquals
(
simplified_num_dims
,
simplified_src0_dims
,
simplified_num_dims
,
simplified_src1_dims
))
{
const
int64_t
elem_cnt
=
GetElementCount
(
simplified_num_dims
,
simplified_src0_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Binary
(
BinaryFunctor
<
DeviceType
::
kCUDA
,
binary_op
,
Src
,
Dst
>
(
attr0
,
attr1
),
elem_cnt
,
dst
,
src0
,
src1
,
cuda_stream
->
cuda_stream
())));
}
else
{
if
(
simplified_num_dims
==
1
&&
simplified_src0_dims
[
0
]
==
1
)
{
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
BinaryLhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src0
,
attr0
,
attr1
),
simplified_src1_dims
[
0
],
dst
,
src1
,
cuda_stream
->
cuda_stream
())));
}
else
if
(
simplified_num_dims
==
1
&&
simplified_src1_dims
[
0
]
==
1
)
{
OF_CUDA_CHECK
((
cuda
::
elementwise
::
UnaryWithFactory
(
BinaryRhsScalarPtrFunctorFactory
<
binary_op
,
Src
,
Dst
>
(
src1
,
attr0
,
attr1
),
simplified_src0_dims
[
0
],
dst
,
src0
,
cuda_stream
->
cuda_stream
())));
}
else
{
LaunchWithSimplified
<
binary_op
,
Src
,
Dst
>
(
stream
,
simplified_num_dims
,
simplified_src0_dims
,
src0
,
simplified_src1_dims
,
src1
,
simplified_dst_dims
,
dst
,
attr0
,
attr1
);
}
}
}
// Extracts the scalar's value as T on the host.
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}
// Specialization for half: fetch the value as float, then narrow.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template
<
BinaryOp
binary_op
,
typename
Src
,
typename
Dst
>
class
BroadcastElementwiseBinaryImpl
:
public
BroadcastElementwiseBinary
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastElementwiseBinaryImpl
);
BroadcastElementwiseBinaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
~
BroadcastElementwiseBinaryImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
Scalar
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
size_t
elem_cnt
=
GetElementCount
(
num_src1_dims
,
src1_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
BinaryLhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src0
),
attr0
,
attr1
),
elem_cnt
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src1
),
cuda_stream
->
cuda_stream
())));
}
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
Scalar
src1
,
void
*
dst
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
size_t
elem_cnt
=
GetElementCount
(
num_src0_dims
,
src0_dims
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
(
BinaryRhsScalarFunctor
<
binary_op
,
Src
,
Dst
>
(
GetValue
<
Src
>
(
src1
),
attr0
,
attr1
),
elem_cnt
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src0
),
cuda_stream
->
cuda_stream
())));
}
void
Launch
(
Stream
*
stream
,
size_t
num_src0_dims
,
const
int64_t
*
src0_dims
,
const
void
*
src0
,
size_t
num_src1_dims
,
const
int64_t
*
src1_dims
,
const
void
*
src1
,
void
*
dst
)
override
{
DispatchLaunch
<
binary_op
,
Src
,
Dst
>
(
stream
,
num_src0_dims
,
src0_dims
,
reinterpret_cast
<
const
Src
*>
(
src0
),
num_src1_dims
,
src1_dims
,
reinterpret_cast
<
const
Src
*>
(
src1
),
reinterpret_cast
<
Dst
*>
(
dst
),
attr0
,
attr1
);
}
private:
Scalar
attr0
,
attr1
;
};
}
// namespace
// Factory: creates a BroadcastElementwiseBinary primitive specialized for
// (binary_op, Src, Dst), forwarding the two op attributes.
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
                                                                          Scalar attr1) {
  // make_unique avoids the naked new; the Derived->Base unique_ptr
  // conversion is implicit. <memory> is already required by the return type.
  return std::make_unique<BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>>(attr0, attr1);
}
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_activation_grad.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \
data_type_pair
)
\
template
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
<
\
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
(
\
Scalar
attr0
,
Scalar
attr1
);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
,
BINARY_ACTIVATION_BACKWARD_OP_SEQ
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
);
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Emits an explicit template instantiation of NewBroadcastElementwiseBinary
// for one (op, type) combination, with identical source and destination type.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,          \
                                                                           data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(          \
      Scalar attr0, Scalar attr1);

// Instantiate every activation-backward op for every CUDA floating type.
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ,
                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_comparision.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicitly instantiates NewBroadcastElementwiseBinary for comparison ops:
// any source type, bool destination.
// NOTE: the macro was previously split across physical lines without
// backslash continuations (ill-formed); the continuations are restored.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                        \
    binary_op, src_data_type_pair, dst_data_type_pair)                                         \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(  \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// One explicit instantiation per (comparison op, src type, dst type) triple;
// comparisons always produce a bool destination.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                        \
    binary_op, src_data_type_pair, dst_data_type_pair)                                         \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(  \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_logical.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicitly instantiates NewBroadcastElementwiseBinary for logical (and
// comparison) ops: any source type, bool destination.
// NOTE: the macro was previously split across physical lines without
// backslash continuations (ill-formed); the continuations are restored.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
                                                                   dst_data_type_pair)            \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<             \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(     \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,
                                 CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// One explicit instantiation per (op, src type, dst type) triple; both the
// comparison and the logical op sequences are covered, dst is always bool.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
                                                                   dst_data_type_pair)            \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<             \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(     \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,
                                 CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary_math.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Explicitly instantiates NewBroadcastElementwiseBinary for math ops; src and
// dst share the same data type.
// NOTE: two defects fixed here: (1) the macro body was split across physical
// lines without backslash continuations (ill-formed preprocessor input);
// (2) the closing brace of `namespace oneflow` was missing.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(          \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// One explicit instantiation per (math op, type) pair; source and destination
// use the same data type.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(          \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_matmul.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_matmul
{
namespace
internal
{
namespace
{
constexpr
size_t
kMaxNumDims
=
8
;
Optional
<
hipblasDatatype_t
>
OptCudaDataType
(
DataType
data_type
)
{
switch
(
data_type
)
{
case
kFloat
:
return
HIPBLAS_R_32F
;
case
kDouble
:
return
HIPBLAS_R_64F
;
case
kFloat16
:
return
HIPBLAS_R_16F
;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return CUDA_R_16BF;
// #endif // CUDA_VERSION >= 11000
default:
return
NullOpt
;
}
}
hipblasDatatype_t
GetCudaDataType
(
DataType
data_type
)
{
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
CHECK
(
cuda_data_type
.
has_value
());
return
cuda_data_type
.
value_or
(
HIPBLAS_R_32F
);
}
union
CublasScalarParameter
{
double
d
;
float
s
;
};
CublasScalarParameter
GetCublasScalarParameter
(
Scalar
scalar
,
hipblasDatatype_t
compute_type
)
{
CublasScalarParameter
sp
{};
if
(
compute_type
==
HIPBLAS_R_64F
)
{
sp
.
d
=
scalar
.
Value
<
double
>
();
}
else
if
(
compute_type
==
HIPBLAS_R_32F
)
{
sp
.
s
=
scalar
.
Value
<
float
>
();
}
else
if
(
compute_type
==
HIPBLAS_R_16F
)
{
sp
.
s
=
scalar
.
Value
<
float
>
();
}
else
{
UNIMPLEMENTED
();
}
return
sp
;
}
hipblasDatatype_t
GetComputeType
(
DataType
data_type
)
{
switch
(
data_type
)
{
case
kFloat
:
return
HIPBLAS_R_32F
;
case
kDouble
:
return
HIPBLAS_R_64F
;
case
kFloat16
:
return
HIPBLAS_R_16F
;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return HIPBLAS_R_32F;
// #endif // CUDA_VERSION >= 11000
default:
UNIMPLEMENTED
();
return
HIPBLAS_R_32F
;
}
}
void
LaunchBroadcastMatmul
(
Stream
*
stream
,
DataType
data_type
,
BlasTransposeType
transpose_a
,
BlasTransposeType
transpose_b
,
int64_t
num_batch_dims
,
const
int64_t
*
broadcast_batch_dims
,
const
int64_t
*
a_batch_dims
,
const
int64_t
*
b_batch_dims
,
const
int64_t
*
c_batch_dims
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
Scalar
alpha
,
const
void
*
a
,
const
void
*
b
,
Scalar
beta
,
void
*
c
)
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
auto
cuda_data_type
=
GetCudaDataType
(
data_type
);
const
auto
compute_type
=
GetComputeType
(
data_type
);
const
auto
sp_alpha
=
GetCublasScalarParameter
(
alpha
,
compute_type
);
__half
h_alpha
=
0
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_alpha
=
__float2half
(
sp_alpha
.
s
);
}
const
auto
GetCublasOperation
=
[](
BlasTransposeType
transpose_type
)
{
if
(
transpose_type
==
BlasTransposeType
::
N
)
{
return
HIPBLAS_OP_N
;
}
else
if
(
transpose_type
==
BlasTransposeType
::
T
)
{
return
HIPBLAS_OP_T
;
}
else
{
UNIMPLEMENTED
();
return
HIPBLAS_OP_N
;
}
};
const
hipblasOperation_t
cublas_trans_a
=
GetCublasOperation
(
transpose_b
);
const
hipblasOperation_t
cublas_trans_b
=
GetCublasOperation
(
transpose_a
);
const
int
cublas_m
=
n
;
const
int
cublas_n
=
m
;
const
int
cublas_k
=
k
;
int
cublas_lda
=
0
;
if
(
transpose_b
==
BlasTransposeType
::
N
)
{
cublas_lda
=
n
;
}
else
if
(
transpose_b
==
BlasTransposeType
::
T
)
{
cublas_lda
=
k
;
}
else
{
UNIMPLEMENTED
();
}
int
cublas_ldb
=
0
;
if
(
transpose_a
==
BlasTransposeType
::
N
)
{
cublas_ldb
=
k
;
}
else
if
(
transpose_a
==
BlasTransposeType
::
T
)
{
cublas_ldb
=
m
;
}
else
{
UNIMPLEMENTED
();
}
const
int
cublas_ldc
=
n
;
// CublasMathModeGuard guard(cuda_stream->cublas_handle());
// if (data_type == DataType::kFloat16) {
// #if CUDA_VERSION < 11000
// guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
// #else
// guard.SetMathMode(CUBLAS_DEFAULT_MATH);
// #endif // CUDA_VERSION < 11000
// }
// #if CUDA_VERSION >= 11000
// hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
hipblasGemmAlgo_t
algo
=
HIPBLAS_GEMM_DEFAULT
;
// #else
// hipblasGemmAlgo_t algo =
// (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT;
// #endif
if
(
num_batch_dims
==
1
&&
c_batch_dims
[
0
]
!=
1
)
{
const
void
*
cublas_a
=
b
;
const
void
*
cublas_b
=
a
;
void
*
cublas_c
=
c
;
const
int64_t
a_batch_count
=
a_batch_dims
[
0
];
const
int64_t
b_batch_count
=
b_batch_dims
[
0
];
CHECK
(
a_batch_count
==
1
||
b_batch_count
==
1
||
a_batch_count
==
b_batch_count
);
CHECK_GT
(
a_batch_count
,
0
);
CHECK_GT
(
b_batch_count
,
0
);
const
int
batch_count
=
std
::
max
(
a_batch_count
,
b_batch_count
);
const
long
long
int
cublas_stride_a
=
b_batch_count
==
1
?
0
:
cublas_m
*
cublas_k
;
const
long
long
int
cublas_stride_b
=
a_batch_count
==
1
?
0
:
cublas_k
*
cublas_n
;
const
long
long
int
cublas_stride_c
=
cublas_m
*
cublas_n
;
const
auto
sp_beta
=
GetCublasScalarParameter
(
beta
,
compute_type
);
__half
h_beta
=
0
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_beta
=
__float2half
(
sp_beta
.
s
);
OF_CUBLAS_CHECK
(
hipblasGemmStridedBatchedEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
h_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_stride_a
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
cublas_stride_b
,
&
h_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
cublas_stride_c
,
batch_count
,
compute_type
,
algo
));
}
else
{
OF_CUBLAS_CHECK
(
hipblasGemmStridedBatchedEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
sp_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_stride_a
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
cublas_stride_b
,
&
sp_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
cublas_stride_c
,
batch_count
,
compute_type
,
algo
));
}
}
else
{
auto
func
=
[
&
](
const
void
*
batch_a
,
const
void
*
batch_b
,
void
*
batch_c
,
Scalar
batch_beta
)
{
const
auto
sp_beta
=
GetCublasScalarParameter
(
batch_beta
,
compute_type
);
__half
h_beta
=
0
;
const
void
*
cublas_a
=
batch_b
;
const
void
*
cublas_b
=
batch_a
;
void
*
cublas_c
=
batch_c
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_beta
=
__float2half
(
sp_beta
.
s
);
OF_CUBLAS_CHECK
(
hipblasGemmEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
h_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
&
h_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
compute_type
,
algo
));
}
else
{
OF_CUBLAS_CHECK
(
hipblasGemmEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
sp_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
&
sp_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
compute_type
,
algo
));
}
};
ForEachMatmul
<
kMaxNumDims
>
(
data_type
,
m
,
n
,
k
,
beta
,
num_batch_dims
,
broadcast_batch_dims
,
a_batch_dims
,
b_batch_dims
,
c_batch_dims
,
a
,
b
,
c
,
func
);
}
}
class
BroadcastMatmulFactoryImpl
:
public
BroadcastMatmulFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastMatmulFactoryImpl
);
BroadcastMatmulFactoryImpl
()
=
default
;
~
BroadcastMatmulFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastMatmul
>
New
(
DataType
data_type
,
BlasTransposeType
transpose_a
,
BlasTransposeType
transpose_b
,
size_t
max_num_dims
)
override
{
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
if
(
max_num_dims
<=
kMaxNumDims
&&
cuda_data_type
.
has_value
())
{
return
std
::
make_unique
<
BroadcastMatmulImpl
<
kMaxNumDims
>>
(
data_type
,
transpose_a
,
transpose_b
);
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastMatmulFactory
,
BroadcastMatmulFactoryImpl
);
}
// namespace
}
// namespace internal
}
// namespace broadcast_matmul
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_matmul
{
namespace
internal
{
namespace
{
constexpr
size_t
kMaxNumDims
=
8
;
Optional
<
hipblasDatatype_t
>
OptCudaDataType
(
DataType
data_type
)
{
switch
(
data_type
)
{
case
kFloat
:
return
HIPBLAS_R_32F
;
case
kDouble
:
return
HIPBLAS_R_64F
;
case
kFloat16
:
return
HIPBLAS_R_16F
;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return CUDA_R_16BF;
// #endif // CUDA_VERSION >= 11000
default:
return
NullOpt
;
}
}
hipblasDatatype_t
GetCudaDataType
(
DataType
data_type
)
{
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
CHECK
(
cuda_data_type
.
has_value
());
return
cuda_data_type
.
value_or
(
HIPBLAS_R_32F
);
}
union
CublasScalarParameter
{
double
d
;
float
s
;
};
CublasScalarParameter
GetCublasScalarParameter
(
Scalar
scalar
,
hipblasDatatype_t
compute_type
)
{
CublasScalarParameter
sp
{};
if
(
compute_type
==
HIPBLAS_R_64F
)
{
sp
.
d
=
scalar
.
Value
<
double
>
();
}
else
if
(
compute_type
==
HIPBLAS_R_32F
)
{
sp
.
s
=
scalar
.
Value
<
float
>
();
}
else
if
(
compute_type
==
HIPBLAS_R_16F
)
{
sp
.
s
=
scalar
.
Value
<
float
>
();
}
else
{
UNIMPLEMENTED
();
}
return
sp
;
}
hipblasDatatype_t
GetComputeType
(
DataType
data_type
)
{
switch
(
data_type
)
{
case
kFloat
:
return
HIPBLAS_R_32F
;
case
kDouble
:
return
HIPBLAS_R_64F
;
case
kFloat16
:
return
HIPBLAS_R_16F
;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return HIPBLAS_R_32F;
// #endif // CUDA_VERSION >= 11000
default:
UNIMPLEMENTED
();
return
HIPBLAS_R_32F
;
}
}
void
LaunchBroadcastMatmul
(
Stream
*
stream
,
DataType
data_type
,
BlasTransposeType
transpose_a
,
BlasTransposeType
transpose_b
,
int64_t
num_batch_dims
,
const
int64_t
*
broadcast_batch_dims
,
const
int64_t
*
a_batch_dims
,
const
int64_t
*
b_batch_dims
,
const
int64_t
*
c_batch_dims
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
Scalar
alpha
,
const
void
*
a
,
const
void
*
b
,
Scalar
beta
,
void
*
c
)
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
const
auto
cuda_data_type
=
GetCudaDataType
(
data_type
);
const
auto
compute_type
=
GetComputeType
(
data_type
);
const
auto
sp_alpha
=
GetCublasScalarParameter
(
alpha
,
compute_type
);
__half
h_alpha
=
0
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_alpha
=
__float2half
(
sp_alpha
.
s
);
}
const
auto
GetCublasOperation
=
[](
BlasTransposeType
transpose_type
)
{
if
(
transpose_type
==
BlasTransposeType
::
N
)
{
return
HIPBLAS_OP_N
;
}
else
if
(
transpose_type
==
BlasTransposeType
::
T
)
{
return
HIPBLAS_OP_T
;
}
else
{
UNIMPLEMENTED
();
return
HIPBLAS_OP_N
;
}
};
const
hipblasOperation_t
cublas_trans_a
=
GetCublasOperation
(
transpose_b
);
const
hipblasOperation_t
cublas_trans_b
=
GetCublasOperation
(
transpose_a
);
const
int
cublas_m
=
n
;
const
int
cublas_n
=
m
;
const
int
cublas_k
=
k
;
int
cublas_lda
=
0
;
if
(
transpose_b
==
BlasTransposeType
::
N
)
{
cublas_lda
=
n
;
}
else
if
(
transpose_b
==
BlasTransposeType
::
T
)
{
cublas_lda
=
k
;
}
else
{
UNIMPLEMENTED
();
}
int
cublas_ldb
=
0
;
if
(
transpose_a
==
BlasTransposeType
::
N
)
{
cublas_ldb
=
k
;
}
else
if
(
transpose_a
==
BlasTransposeType
::
T
)
{
cublas_ldb
=
m
;
}
else
{
UNIMPLEMENTED
();
}
const
int
cublas_ldc
=
n
;
// CublasMathModeGuard guard(cuda_stream->cublas_handle());
// if (data_type == DataType::kFloat16) {
// #if CUDA_VERSION < 11000
// guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
// #else
// guard.SetMathMode(CUBLAS_DEFAULT_MATH);
// #endif // CUDA_VERSION < 11000
// }
// #if CUDA_VERSION >= 11000
// hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
hipblasGemmAlgo_t
algo
=
HIPBLAS_GEMM_DEFAULT
;
// #else
// hipblasGemmAlgo_t algo =
// (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT;
// #endif
if
(
num_batch_dims
==
1
&&
c_batch_dims
[
0
]
!=
1
)
{
const
void
*
cublas_a
=
b
;
const
void
*
cublas_b
=
a
;
void
*
cublas_c
=
c
;
const
int64_t
a_batch_count
=
a_batch_dims
[
0
];
const
int64_t
b_batch_count
=
b_batch_dims
[
0
];
CHECK
(
a_batch_count
==
1
||
b_batch_count
==
1
||
a_batch_count
==
b_batch_count
);
CHECK_GT
(
a_batch_count
,
0
);
CHECK_GT
(
b_batch_count
,
0
);
const
int
batch_count
=
std
::
max
(
a_batch_count
,
b_batch_count
);
const
long
long
int
cublas_stride_a
=
b_batch_count
==
1
?
0
:
cublas_m
*
cublas_k
;
const
long
long
int
cublas_stride_b
=
a_batch_count
==
1
?
0
:
cublas_k
*
cublas_n
;
const
long
long
int
cublas_stride_c
=
cublas_m
*
cublas_n
;
const
auto
sp_beta
=
GetCublasScalarParameter
(
beta
,
compute_type
);
__half
h_beta
=
0
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_beta
=
__float2half
(
sp_beta
.
s
);
OF_CUBLAS_CHECK
(
hipblasGemmStridedBatchedEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
h_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_stride_a
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
cublas_stride_b
,
&
h_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
cublas_stride_c
,
batch_count
,
compute_type
,
algo
));
}
else
{
OF_CUBLAS_CHECK
(
hipblasGemmStridedBatchedEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
sp_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_stride_a
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
cublas_stride_b
,
&
sp_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
cublas_stride_c
,
batch_count
,
compute_type
,
algo
));
}
}
else
{
auto
func
=
[
&
](
const
void
*
batch_a
,
const
void
*
batch_b
,
void
*
batch_c
,
Scalar
batch_beta
)
{
const
auto
sp_beta
=
GetCublasScalarParameter
(
batch_beta
,
compute_type
);
__half
h_beta
=
0
;
const
void
*
cublas_a
=
batch_b
;
const
void
*
cublas_b
=
batch_a
;
void
*
cublas_c
=
batch_c
;
if
(
compute_type
==
HIPBLAS_R_16F
)
{
h_beta
=
__float2half
(
sp_beta
.
s
);
OF_CUBLAS_CHECK
(
hipblasGemmEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
h_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
&
h_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
compute_type
,
algo
));
}
else
{
OF_CUBLAS_CHECK
(
hipblasGemmEx
(
cuda_stream
->
cublas_handle
(),
cublas_trans_a
,
cublas_trans_b
,
cublas_m
,
cublas_n
,
cublas_k
,
&
sp_alpha
,
cublas_a
,
cuda_data_type
,
cublas_lda
,
cublas_b
,
cuda_data_type
,
cublas_ldb
,
&
sp_beta
,
cublas_c
,
cuda_data_type
,
cublas_ldc
,
compute_type
,
algo
));
}
};
ForEachMatmul
<
kMaxNumDims
>
(
data_type
,
m
,
n
,
k
,
beta
,
num_batch_dims
,
broadcast_batch_dims
,
a_batch_dims
,
b_batch_dims
,
c_batch_dims
,
a
,
b
,
c
,
func
);
}
}
class
BroadcastMatmulFactoryImpl
:
public
BroadcastMatmulFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BroadcastMatmulFactoryImpl
);
BroadcastMatmulFactoryImpl
()
=
default
;
~
BroadcastMatmulFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
BroadcastMatmul
>
New
(
DataType
data_type
,
BlasTransposeType
transpose_a
,
BlasTransposeType
transpose_b
,
size_t
max_num_dims
)
override
{
auto
cuda_data_type
=
OptCudaDataType
(
data_type
);
if
(
max_num_dims
<=
kMaxNumDims
&&
cuda_data_type
.
has_value
())
{
return
std
::
make_unique
<
BroadcastMatmulImpl
<
kMaxNumDims
>>
(
data_type
,
transpose_a
,
transpose_b
);
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BroadcastMatmulFactory
,
BroadcastMatmulFactoryImpl
);
}
// namespace
}
// namespace internal
}
// namespace broadcast_matmul
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif // WITH_ROCM
oneflow/core/ep/rocm/primitive/cast.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
typename
To
,
typename
From
,
typename
=
void
>
struct
CastFunctor
{
__device__
To
operator
()(
From
from
)
const
{
return
static_cast
<
To
>
(
from
);
}
};
template
<
typename
To
>
struct
CastFunctor
<
To
,
half
,
typename
std
::
enable_if
<!
std
::
is_same
<
To
,
half
>::
value
>::
type
>
{
__device__
To
operator
()(
half
from
)
const
{
return
static_cast
<
To
>
(
static_cast
<
float
>
(
from
));
}
__device__
void
Apply2
(
To
*
to
,
const
half
*
from
)
const
{
const
float2
f2
=
__half22float2
(
*
reinterpret_cast
<
const
half2
*>
(
from
));
to
[
0
]
=
static_cast
<
To
>
(
f2
.
x
);
to
[
1
]
=
static_cast
<
To
>
(
f2
.
y
);
}
};
template
<
typename
From
>
struct
CastFunctor
<
half
,
From
,
typename
std
::
enable_if
<!
std
::
is_same
<
From
,
half
>::
value
>::
type
>
{
__device__
half
operator
()(
From
from
)
const
{
return
static_cast
<
half
>
(
static_cast
<
float
>
(
from
));
}
__device__
void
Apply2
(
half
*
to
,
const
From
*
from
)
const
{
float2
f2
;
f2
.
x
=
static_cast
<
float
>
(
from
[
0
]);
f2
.
y
=
static_cast
<
float
>
(
from
[
1
]);
*
reinterpret_cast
<
half2
*>
(
to
)
=
__float22half2_rn
(
f2
);
}
};
// #if CUDA_VERSION >= 11000
// template<typename To>
// struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from));
// }
// };
// template<typename From>
// struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// }
// };
// #endif // CUDA_VERSION >= 11000
template
<
typename
From
,
typename
To
>
class
CastImpl
:
public
Cast
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastImpl
);
explicit
CastImpl
()
=
default
;
~
CastImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
from
,
void
*
to
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
CastFunctor
<
To
,
From
>
,
To
,
From
>
(
CastFunctor
<
To
,
From
>
(),
count
,
reinterpret_cast
<
To
*>
(
to
),
reinterpret_cast
<
const
From
*>
(
from
),
cuda_stream
->
cuda_stream
())));
}
};
template
<
typename
From
,
typename
To
>
std
::
unique_ptr
<
Cast
>
NewCast
()
{
return
std
::
unique_ptr
<
Cast
>
(
new
CastImpl
<
From
,
To
>
());
}
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ
\
CUDA_PRIMITIVE_CHAR_TYPE_SEQ
\
CUDA_PRIMITIVE_INT8_TYPE_SEQ
\
CUDA_PRIMITIVE_UINT8_TYPE_SEQ
\
CUDA_PRIMITIVE_INT32_TYPE_SEQ
\
CUDA_PRIMITIVE_UINT32_TYPE_SEQ
\
CUDA_PRIMITIVE_INT64_TYPE_SEQ
\
CUDA_PRIMITIVE_UINT64_TYPE_SEQ
\
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ
\
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ
\
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ
\
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
class
CastFactoryImpl
:
public
CastFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastFactoryImpl
);
CastFactoryImpl
()
=
default
;
~
CastFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Cast
>
New
(
DataType
from
,
DataType
to
)
override
{
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{
std
::
make_pair
(
OF_PP_PAIR_SECOND
(
from_pair
),
OF_PP_PAIR_SECOND
(
to_pair
)),
\
NewCast
<
OF_PP_PAIR_FIRST
(
from_pair
),
OF_PP_PAIR_FIRST
(
to_pair
)
>
},
static
const
std
::
map
<
std
::
pair
<
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
Cast
>
()
>>
new_cast_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_CAST_ENTRY
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
)};
#undef MAKE_NEW_CAST_ENTRY
const
auto
it
=
new_cast_handle
.
find
(
std
::
make_pair
(
from
,
to
));
if
(
it
!=
new_cast_handle
.
end
())
{
return
it
->
second
();
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
CastFactory
,
CastFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// Generic element cast: a plain static_cast between the two types.
template<typename To, typename From, typename = void>
struct CastFunctor {
  __device__ To operator()(From from) const { return static_cast<To>(from); }
};

// half -> To goes through float; Apply2 converts a half2 pair at once.
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
  __device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }

  __device__ void Apply2(To* to, const half* from) const {
    const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));
    to[0] = static_cast<To>(f2.x);
    to[1] = static_cast<To>(f2.y);
  }
};

// From -> half goes through float; Apply2 packs two results into a half2.
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
  __device__ half operator()(From from) const {
    return static_cast<half>(static_cast<float>(from));
  }

  __device__ void Apply2(half* to, const From* from) const {
    float2 f2;
    f2.x = static_cast<float>(from[0]);
    f2.y = static_cast<float>(from[1]);
    *reinterpret_cast<half2*>(to) = __float22half2_rn(f2);
  }
};
// #if CUDA_VERSION >= 11000
// template<typename To>
// struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from));
// }
// };
// template<typename From>
// struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// }
// };
// #endif // CUDA_VERSION >= 11000
template
<
typename
From
,
typename
To
>
class
CastImpl
:
public
Cast
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastImpl
);
explicit
CastImpl
()
=
default
;
~
CastImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
from
,
void
*
to
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
CastFunctor
<
To
,
From
>
,
To
,
From
>
(
CastFunctor
<
To
,
From
>
(),
count
,
reinterpret_cast
<
To
*>
(
to
),
reinterpret_cast
<
const
From
*>
(
from
),
cuda_stream
->
cuda_stream
())));
}
};
template
<
typename
From
,
typename
To
>
std
::
unique_ptr
<
Cast
>
NewCast
()
{
return
std
::
unique_ptr
<
Cast
>
(
new
CastImpl
<
From
,
To
>
());
}
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_UINT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_UINT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
class
CastFactoryImpl
:
public
CastFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CastFactoryImpl
);
CastFactoryImpl
()
=
default
;
~
CastFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Cast
>
New
(
DataType
from
,
DataType
to
)
override
{
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static
const
std
::
map
<
std
::
pair
<
DataType
,
DataType
>
,
std
::
function
<
std
::
unique_ptr
<
Cast
>
()
>>
new_cast_handle
{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MAKE_NEW_CAST_ENTRY
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
,
CUDA_PRIMITIVE_CAST_TYPE_SEQ
)};
#undef MAKE_NEW_CAST_ENTRY
const
auto
it
=
new_cast_handle
.
find
(
std
::
make_pair
(
from
,
to
));
if
(
it
!=
new_cast_handle
.
end
())
{
return
it
->
second
();
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
CastFactory
,
CastFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/constant_pad.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
__global__
void
ConstantPadKernel
(
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
)
{
const
StorageType
*
src
=
reinterpret_cast
<
const
StorageType
*>
(
params
.
src
);
StorageType
*
dst
=
reinterpret_cast
<
StorageType
*>
(
params
.
dst
);
IndexType
src_index
[
num_dims
];
IndexType
dst_index
[
num_dims
];
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
linear_index
,
params
.
elem_cnt
)
{
params
.
dst_index_helper
.
OffsetToNdIndex
(
linear_index
,
dst_index
);
bool
if_pad
=
false
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
if
(
dst_index
[
i
]
>=
params
.
valid_start
[
i
]
&&
dst_index
[
i
]
<
params
.
valid_end
[
i
])
{
src_index
[
i
]
=
dst_index
[
i
]
-
params
.
valid_start
[
i
];
}
else
{
if_pad
=
true
;
break
;
}
}
StorageType
dst_val
=
packed_pad_val
;
if
(
!
if_pad
)
{
const
IndexType
src_offset
=
params
.
src_index_helper
.
NdIndexToOffset
(
src_index
);
dst_val
=
src
[
src_offset
];
}
dst
[
linear_index
]
=
dst_val
;
}
}
template
<
>
half
GetValue
<
half
>
(
Scalar
value
)
{
return
static_cast
<
half
>
(
GetValue
<
float
>
(
value
));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
stream
->
As
<
CudaStream
>
()
->
LaunchKernelDefaultWaves
(
(
ConstantPadKernel
<
num_dims
,
IndexType
,
StorageType
>
),
elem_cnt
,
params
,
packed_pad_val
);
}
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
ConstantPadParams
<
num_dims
,
IndexType
>
params
;
params
.
dst_index_helper
=
OffsetToIndexCalculator
<
IndexType
,
num_dims
>
(
dst_dims
);
params
.
src_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
num_dims
>
(
src_dims
);
params
.
dst
=
dst
;
params
.
src
=
src
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
params
.
valid_start
[
i
]
=
padding_before
[
i
];
params
.
valid_end
[
i
]
=
dst_dims
[
i
]
-
padding_after
[
i
];
}
params
.
elem_cnt
=
elem_cnt
;
LaunchKernel
<
num_dims
,
IndexType
,
StorageType
>
(
stream
,
params
,
packed_pad_val
,
elem_cnt
);
}
template
<
size_t
num_dims
,
typename
StorageType
>
void
DispatchIndexType
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
if
(
elem_cnt
<
GetMaxVal
<
int32_t
>
())
{
LaunchKernel
<
num_dims
,
int32_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
}
else
{
LaunchKernel
<
num_dims
,
int64_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
}
}
template
<
size_t
num_dims
,
typename
T
>
void
DispatchPackSize
(
Stream
*
stream
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
constexpr
int32_t
max_packsize
=
GetMaxPackSize
<
T
>
();
size_t
launch_pack_size
=
GetLaunchPackSize
<
max_packsize
>
(
num_dims
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
);
dst_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
src_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_before
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_after
[
num_dims
-
1
]
/=
launch_pack_size
;
size_t
elem_cnt
=
1
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
elem_cnt
*=
dst_dims
[
i
];
}
if
(
launch_pack_size
==
1
)
{
Pack
<
T
,
1
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
1
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
2
)
{
Pack
<
T
,
2
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
2
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
4
)
{
Pack
<
T
,
4
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
4
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
8
)
{
Pack
<
T
,
8
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
8
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
16
)
{
Pack
<
T
,
16
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
16
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
{
UNIMPLEMENTED
();
}
}
template
<
typename
T
>
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
num_dims
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
void
*
/*dst*/
,
int64_t
*
/*dst_dims*/
,
const
void
*
/*src*/
,
int64_t
*
/*src_dims*/
,
int64_t
*
/*padding_before*/
,
int64_t
*
/*padding_after*/
,
T
)
=
nullptr
;
if
(
num_dims
==
1
)
{
func
=
DispatchPackSize
<
1
,
T
>
;
}
else
if
(
num_dims
==
2
)
{
func
=
DispatchPackSize
<
2
,
T
>
;
}
else
if
(
num_dims
==
3
)
{
func
=
DispatchPackSize
<
3
,
T
>
;
}
else
if
(
num_dims
==
4
)
{
func
=
DispatchPackSize
<
4
,
T
>
;
}
else
if
(
num_dims
==
5
)
{
func
=
DispatchPackSize
<
5
,
T
>
;
}
else
if
(
num_dims
==
6
)
{
func
=
DispatchPackSize
<
6
,
T
>
;
}
else
if
(
num_dims
==
7
)
{
func
=
DispatchPackSize
<
7
,
T
>
;
}
else
if
(
num_dims
==
8
)
{
func
=
DispatchPackSize
<
8
,
T
>
;
}
else
{
UNIMPLEMENTED
();
}
func
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
pad_val
);
}
template
<
typename
T
>
void
SimplifyThenLaunch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
T
pad_val
,
void
*
dst
)
{
CHECK_LE
(
num_dims
,
kMaxNumDims
);
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_src_dims
[
kMaxNumDims
];
int64_t
simplified_padding_before
[
kMaxNumDims
];
int64_t
simplified_padding_after
[
kMaxNumDims
];
size_t
simplified_num_dims
=
1
;
SimplifyPadDims
(
num_dims
,
src_dims
,
padding_before
,
padding_after
,
&
simplified_num_dims
,
simplified_dst_dims
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
);
LaunchWithSimplified
<
T
>
(
stream
,
simplified_num_dims
,
dst
,
simplified_dst_dims
,
src
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
,
pad_val
);
}
template
<
typename
T
>
class
ConstantPadImpl
:
public
ConstantPad
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadImpl
);
ConstantPadImpl
()
=
default
;
~
ConstantPadImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
Scalar
pad_val
,
void
*
dst
)
override
{
SimplifyThenLaunch
<
T
>
(
stream
,
num_dims
,
src_dims
,
src
,
padding_before
,
padding_after
,
GetValue
<
T
>
(
pad_val
),
dst
);
}
};
template
<
typename
T
>
std
::
unique_ptr
<
ConstantPad
>
NewConstantPad
()
{
return
std
::
unique_ptr
<
ConstantPad
>
(
new
ConstantPadImpl
<
T
>
());
}
class
ConstantPadFactoryImpl
:
public
ConstantPadFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadFactoryImpl
);
ConstantPadFactoryImpl
()
=
default
;
~
ConstantPadFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
ConstantPad
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
ConstantPad
>
()
>>
new_constant_pad_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_CONSTANT_PAD_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
const
auto
it
=
new_constant_pad_handle
.
find
(
data_type
);
if
(
it
!=
new_constant_pad_handle
.
end
())
{
return
it
->
second
();
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
ConstantPadFactory
,
ConstantPadFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
__global__
void
ConstantPadKernel
(
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
)
{
const
StorageType
*
src
=
reinterpret_cast
<
const
StorageType
*>
(
params
.
src
);
StorageType
*
dst
=
reinterpret_cast
<
StorageType
*>
(
params
.
dst
);
IndexType
src_index
[
num_dims
];
IndexType
dst_index
[
num_dims
];
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
linear_index
,
params
.
elem_cnt
)
{
params
.
dst_index_helper
.
OffsetToNdIndex
(
linear_index
,
dst_index
);
bool
if_pad
=
false
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
if
(
dst_index
[
i
]
>=
params
.
valid_start
[
i
]
&&
dst_index
[
i
]
<
params
.
valid_end
[
i
])
{
src_index
[
i
]
=
dst_index
[
i
]
-
params
.
valid_start
[
i
];
}
else
{
if_pad
=
true
;
break
;
}
}
StorageType
dst_val
=
packed_pad_val
;
if
(
!
if_pad
)
{
const
IndexType
src_offset
=
params
.
src_index_helper
.
NdIndexToOffset
(
src_index
);
dst_val
=
src
[
src_offset
];
}
dst
[
linear_index
]
=
dst_val
;
}
}
template
<
>
half
GetValue
<
half
>
(
Scalar
value
)
{
return
static_cast
<
half
>
(
GetValue
<
float
>
(
value
));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
ConstantPadParams
<
num_dims
,
IndexType
>
params
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
stream
->
As
<
CudaStream
>
()
->
LaunchKernelDefaultWaves
(
(
ConstantPadKernel
<
num_dims
,
IndexType
,
StorageType
>
),
elem_cnt
,
params
,
packed_pad_val
);
}
template
<
size_t
num_dims
,
typename
IndexType
,
typename
StorageType
>
void
LaunchKernel
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
ConstantPadParams
<
num_dims
,
IndexType
>
params
;
params
.
dst_index_helper
=
OffsetToIndexCalculator
<
IndexType
,
num_dims
>
(
dst_dims
);
params
.
src_index_helper
=
NdIndexOffsetHelper
<
IndexType
,
num_dims
>
(
src_dims
);
params
.
dst
=
dst
;
params
.
src
=
src
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
params
.
valid_start
[
i
]
=
padding_before
[
i
];
params
.
valid_end
[
i
]
=
dst_dims
[
i
]
-
padding_after
[
i
];
}
params
.
elem_cnt
=
elem_cnt
;
LaunchKernel
<
num_dims
,
IndexType
,
StorageType
>
(
stream
,
params
,
packed_pad_val
,
elem_cnt
);
}
template
<
size_t
num_dims
,
typename
StorageType
>
void
DispatchIndexType
(
Stream
*
stream
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
StorageType
packed_pad_val
,
size_t
elem_cnt
)
{
if
(
elem_cnt
<
GetMaxVal
<
int32_t
>
())
{
LaunchKernel
<
num_dims
,
int32_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
}
else
{
LaunchKernel
<
num_dims
,
int64_t
,
StorageType
>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
,
elem_cnt
);
}
}
template
<
size_t
num_dims
,
typename
T
>
void
DispatchPackSize
(
Stream
*
stream
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
constexpr
int32_t
max_packsize
=
GetMaxPackSize
<
T
>
();
size_t
launch_pack_size
=
GetLaunchPackSize
<
max_packsize
>
(
num_dims
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
);
dst_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
src_dims
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_before
[
num_dims
-
1
]
/=
launch_pack_size
;
padding_after
[
num_dims
-
1
]
/=
launch_pack_size
;
size_t
elem_cnt
=
1
;
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
elem_cnt
*=
dst_dims
[
i
];
}
if
(
launch_pack_size
==
1
)
{
Pack
<
T
,
1
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
1
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
2
)
{
Pack
<
T
,
2
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
2
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
4
)
{
Pack
<
T
,
4
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
4
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
8
)
{
Pack
<
T
,
8
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
8
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
if
(
launch_pack_size
==
16
)
{
Pack
<
T
,
16
>
packed_pad_val
(
pad_val
);
DispatchIndexType
<
num_dims
,
PackType
<
T
,
16
>>
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
packed_pad_val
.
storage
,
elem_cnt
);
}
else
{
UNIMPLEMENTED
();
}
}
template
<
typename
T
>
void
LaunchWithSimplified
(
Stream
*
stream
,
size_t
num_dims
,
void
*
dst
,
int64_t
*
dst_dims
,
const
void
*
src
,
int64_t
*
src_dims
,
int64_t
*
padding_before
,
int64_t
*
padding_after
,
T
pad_val
)
{
void
(
*
func
)(
Stream
*
/*stream*/
,
void
*
/*dst*/
,
int64_t
*
/*dst_dims*/
,
const
void
*
/*src*/
,
int64_t
*
/*src_dims*/
,
int64_t
*
/*padding_before*/
,
int64_t
*
/*padding_after*/
,
T
)
=
nullptr
;
if
(
num_dims
==
1
)
{
func
=
DispatchPackSize
<
1
,
T
>
;
}
else
if
(
num_dims
==
2
)
{
func
=
DispatchPackSize
<
2
,
T
>
;
}
else
if
(
num_dims
==
3
)
{
func
=
DispatchPackSize
<
3
,
T
>
;
}
else
if
(
num_dims
==
4
)
{
func
=
DispatchPackSize
<
4
,
T
>
;
}
else
if
(
num_dims
==
5
)
{
func
=
DispatchPackSize
<
5
,
T
>
;
}
else
if
(
num_dims
==
6
)
{
func
=
DispatchPackSize
<
6
,
T
>
;
}
else
if
(
num_dims
==
7
)
{
func
=
DispatchPackSize
<
7
,
T
>
;
}
else
if
(
num_dims
==
8
)
{
func
=
DispatchPackSize
<
8
,
T
>
;
}
else
{
UNIMPLEMENTED
();
}
func
(
stream
,
dst
,
dst_dims
,
src
,
src_dims
,
padding_before
,
padding_after
,
pad_val
);
}
template
<
typename
T
>
void
SimplifyThenLaunch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
T
pad_val
,
void
*
dst
)
{
CHECK_LE
(
num_dims
,
kMaxNumDims
);
int64_t
simplified_dst_dims
[
kMaxNumDims
];
int64_t
simplified_src_dims
[
kMaxNumDims
];
int64_t
simplified_padding_before
[
kMaxNumDims
];
int64_t
simplified_padding_after
[
kMaxNumDims
];
size_t
simplified_num_dims
=
1
;
SimplifyPadDims
(
num_dims
,
src_dims
,
padding_before
,
padding_after
,
&
simplified_num_dims
,
simplified_dst_dims
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
);
LaunchWithSimplified
<
T
>
(
stream
,
simplified_num_dims
,
dst
,
simplified_dst_dims
,
src
,
simplified_src_dims
,
simplified_padding_before
,
simplified_padding_after
,
pad_val
);
}
template
<
typename
T
>
class
ConstantPadImpl
:
public
ConstantPad
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadImpl
);
ConstantPadImpl
()
=
default
;
~
ConstantPadImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
padding_before
,
const
int64_t
*
padding_after
,
Scalar
pad_val
,
void
*
dst
)
override
{
SimplifyThenLaunch
<
T
>
(
stream
,
num_dims
,
src_dims
,
src
,
padding_before
,
padding_after
,
GetValue
<
T
>
(
pad_val
),
dst
);
}
};
template
<
typename
T
>
std
::
unique_ptr
<
ConstantPad
>
NewConstantPad
()
{
return
std
::
unique_ptr
<
ConstantPad
>
(
new
ConstantPadImpl
<
T
>
());
}
class
ConstantPadFactoryImpl
:
public
ConstantPadFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ConstantPadFactoryImpl
);
ConstantPadFactoryImpl
()
=
default
;
~
ConstantPadFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
ConstantPad
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
ConstantPad
>
()
>>
new_constant_pad_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_CONSTANT_PAD_ENTRY
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
const
auto
it
=
new_constant_pad_handle
.
find
(
data_type
);
if
(
it
!=
new_constant_pad_handle
.
end
())
{
return
it
->
second
();
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
ConstantPadFactory
,
ConstantPadFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/copy_nd.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
__global__
void
CopyNdKernel
(
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
using
T
=
typename
std
::
aligned_storage
<
movement_size
,
movement_size
>::
type
;
const
T
*
src
=
reinterpret_cast
<
const
T
*>
(
params
.
src
);
T
*
dst
=
reinterpret_cast
<
T
*>
(
params
.
dst
);
IndexType
copy_index
[
num_dims
];
IndexType
src_index
[
num_dims
];
IndexType
dst_index
[
num_dims
];
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
i
,
params
.
count
)
{
params
.
copy_index_helper
.
OffsetToNdIndex
(
i
,
copy_index
);
#pragma unroll
for
(
size_t
j
=
0
;
j
<
num_dims
;
++
j
)
{
src_index
[
j
]
=
params
.
src_pos
[
j
]
+
copy_index
[
j
];
dst_index
[
j
]
=
params
.
dst_pos
[
j
]
+
copy_index
[
j
];
}
const
IndexType
src_offset
=
params
.
src_index_helper
.
NdIndexToOffset
(
src_index
);
const
IndexType
dst_offset
=
params
.
dst_index_helper
.
NdIndexToOffset
(
dst_index
);
dst
[
dst_offset
]
=
src
[
src_offset
];
}
}
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
CopyNdKernel
<
num_dims
,
movement_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
}
class
CopyNdImpl
:
public
CopyNd
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdImpl
);
CopyNdImpl
()
=
default
;
~
CopyNdImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
int64_t
*
dst_pos
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
src_pos
,
const
int64_t
*
extent
)
const
override
{
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
dst
,
dst_dims
,
dst_pos
,
src
,
src_dims
,
src_pos
,
extent
);
}
};
class
CopyNdFactoryImpl
:
public
CopyNdFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdFactoryImpl
);
CopyNdFactoryImpl
()
=
default
;
~
CopyNdFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
CopyNd
>
New
(
size_t
max_num_dims
)
override
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
return
std
::
unique_ptr
<
CopyNd
>
(
new
CopyNdImpl
());
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
CopyNdFactory
,
CopyNdFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
__global__
void
CopyNdKernel
(
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
using
T
=
typename
std
::
aligned_storage
<
movement_size
,
movement_size
>::
type
;
const
T
*
src
=
reinterpret_cast
<
const
T
*>
(
params
.
src
);
T
*
dst
=
reinterpret_cast
<
T
*>
(
params
.
dst
);
IndexType
copy_index
[
num_dims
];
IndexType
src_index
[
num_dims
];
IndexType
dst_index
[
num_dims
];
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
i
,
params
.
count
)
{
params
.
copy_index_helper
.
OffsetToNdIndex
(
i
,
copy_index
);
#pragma unroll
for
(
size_t
j
=
0
;
j
<
num_dims
;
++
j
)
{
src_index
[
j
]
=
params
.
src_pos
[
j
]
+
copy_index
[
j
];
dst_index
[
j
]
=
params
.
dst_pos
[
j
]
+
copy_index
[
j
];
}
const
IndexType
src_offset
=
params
.
src_index_helper
.
NdIndexToOffset
(
src_index
);
const
IndexType
dst_offset
=
params
.
dst_index_helper
.
NdIndexToOffset
(
dst_index
);
dst
[
dst_offset
]
=
src
[
src_offset
];
}
}
template
<
size_t
num_dims
,
size_t
movement_size
,
typename
IndexType
>
void
LaunchKernel
(
Stream
*
stream
,
CopyNdKernelParams
<
num_dims
,
IndexType
>
params
)
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
CopyNdKernel
<
num_dims
,
movement_size
,
IndexType
>
<<<
BlocksNum4ThreadsNum
(
params
.
count
),
kCudaThreadsNumPerBlock
,
0
,
cuda_stream
>>>
(
params
);
}
class
CopyNdImpl
:
public
CopyNd
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdImpl
);
CopyNdImpl
()
=
default
;
~
CopyNdImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
DataType
data_type
,
size_t
num_dims
,
void
*
dst
,
const
int64_t
*
dst_dims
,
const
int64_t
*
dst_pos
,
const
void
*
src
,
const
int64_t
*
src_dims
,
const
int64_t
*
src_pos
,
const
int64_t
*
extent
)
const
override
{
SimplifyThenLaunch
(
stream
,
data_type
,
num_dims
,
dst
,
dst_dims
,
dst_pos
,
src
,
src_dims
,
src_pos
,
extent
);
}
};
class
CopyNdFactoryImpl
:
public
CopyNdFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyNdFactoryImpl
);
CopyNdFactoryImpl
()
=
default
;
~
CopyNdFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
CopyNd
>
New
(
size_t
max_num_dims
)
override
{
if
(
max_num_dims
<=
kMaxNumDims
)
{
return
std
::
unique_ptr
<
CopyNd
>
(
new
CopyNdImpl
());
}
else
{
return
nullptr
;
}
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
CopyNdFactory
,
CopyNdFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/elementwise_unary.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
class
ElementwiseUnaryImpl
:
public
ElementwiseUnary
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ElementwiseUnaryImpl
);
ElementwiseUnaryImpl
(
Scalar
attr0
,
Scalar
attr1
)
:
attr0
(
attr0
),
attr1
(
attr1
)
{}
~
ElementwiseUnaryImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
const
void
*
src
,
void
*
dst
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
auto
functor
=
UnaryFunctor
<
DeviceType
::
kCUDA
,
unary_op
,
Dst
,
Src
>
(
attr0
,
attr1
);
OF_CUDA_CHECK
((
cuda
::
elementwise
::
Unary
<
decltype
(
functor
),
Dst
,
Src
>
(
functor
,
count
,
reinterpret_cast
<
Dst
*>
(
dst
),
reinterpret_cast
<
const
Src
*>
(
src
),
cuda_stream
->
cuda_stream
())));
}
protected:
Scalar
attr0
,
attr1
;
};
// Builds a type-erased ElementwiseUnary backed by
// ElementwiseUnaryImpl<unary_op, Src, Dst>, forwarding the two scalar
// attributes to its constructor.
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {
  using Impl = ElementwiseUnaryImpl<unary_op, Src, Dst>;
  return std::unique_ptr<ElementwiseUnary>(new Impl(attr0, attr1));
}
// Factory that maps a (unary op, src dtype, dst dtype) triple to a concrete
// ElementwiseUnary instance. The dispatch table is a static std::map whose
// entries are generated with OneFlow's preprocessor-sequence macros.
class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);
  ElementwiseUnaryFactoryImpl() = default;
  ~ElementwiseUnaryFactoryImpl() override = default;

  // Convenience overload: no scalar attributes.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
                                        DataType dst_dtype) override {
    return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());
  }

  // Convenience overload: one scalar attribute.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0) override {
    return New(unary_op, src_type, dst_dtype, attr0, Scalar());
  }

  // Full overload: looks up (op, src, dst) in the static table and invokes the
  // matching creator. Returns nullptr when the combination is not registered.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0, Scalar attr1) override {
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)                   \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair)  \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair),                                  \
                       OF_PP_PAIR_FIRST(dst_dtype_pair)>},
    static const std::map<std::tuple<UnaryOp, DataType, DataType>,
                          std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>
        new_elementwise_unary_handle{
            // For All Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
            // For Float Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_FLOATING_MATH_OP_SEQ,
                                             CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
            // For Utils OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
            // For Logical OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
    const auto it =
        new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));
    if (it != new_elementwise_unary_handle.end()) {
      return it->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory,
                           ElementwiseUnaryFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// NOTE(review): this span is the second side of the GitLab diff and repeats the
// definitions above verbatim; documented identically.

// Elementwise unary primitive for the ROCm (HIP) backend. Applies the functor
// selected by `unary_op` to every element of `src`, writing results to `dst`.
template<UnaryOp unary_op, typename Src, typename Dst>
class ElementwiseUnaryImpl : public ElementwiseUnary {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);
  // attr0/attr1 are op-specific scalar attributes, forwarded to UnaryFunctor.
  ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
  ~ElementwiseUnaryImpl() override = default;

  // Launches the elementwise kernel asynchronously on `stream`.
  void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
    auto* cuda_stream = stream->As<CudaStream>();
    auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);
    OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(
        functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src),
        cuda_stream->cuda_stream())));
  }

 protected:
  Scalar attr0, attr1;
};

// Builds a type-erased ElementwiseUnary backed by the impl above.
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {
  return std::unique_ptr<ElementwiseUnary>(
      new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}

// Factory mapping (op, src dtype, dst dtype) to a concrete ElementwiseUnary
// via a static, macro-generated lookup table.
class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);
  ElementwiseUnaryFactoryImpl() = default;
  ~ElementwiseUnaryFactoryImpl() override = default;

  // Convenience overload: no scalar attributes.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
                                        DataType dst_dtype) override {
    return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());
  }

  // Convenience overload: one scalar attribute.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0) override {
    return New(unary_op, src_type, dst_dtype, attr0, Scalar());
  }

  // Full overload: table lookup; returns nullptr for unregistered triples.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0, Scalar attr1) override {
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)                   \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair)  \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair),                                  \
                       OF_PP_PAIR_FIRST(dst_dtype_pair)>},
    static const std::map<std::tuple<UnaryOp, DataType, DataType>,
                          std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>
        new_elementwise_unary_handle{
            // For All Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
            // For Float Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_FLOATING_MATH_OP_SEQ,
                                             CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
            // For Utils OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
            // For Logical OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
    const auto it =
        new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));
    if (it != new_elementwise_unary_handle.end()) {
      return it->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory,
                           ElementwiseUnaryFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
oneflow/core/ep/rocm/primitive/fill.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// Storage<size>: a POD blob whose size and alignment both equal `size`;
// used as the vectorized load/store unit for packed fills.
template<size_t size>
using Storage = typename std::aligned_storage<size, size>::type;

// Pack: `pack` copies of T, viewable either element-wise or as one aligned
// Storage word, so a fill can write `pack` elements per store.
template<typename T, size_t pack>
union Pack {
  static constexpr size_t size = sizeof(T) * pack;
  // Replicates `value` into every slot of the pack.
  explicit __device__ __host__ Pack(T value) {
    static_assert(sizeof(Pack) == size, "");
    static_assert(alignof(Pack) == size, "");
#pragma unroll
    for (size_t i = 0; i < pack; ++i) { elem[i] = value; }
  }
  T elem[pack];
  Storage<size> storage;  // Aliased aligned view of the whole pack.
};
// Kernel: fills dst[0, count) with `value`. The bulk of the range is written
// in `pack`-wide Storage words; the final count % pack elements are written
// one at a time.
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
  const size_t pack_count = count / pack;
  Pack<T, pack> pack_value(value);
  auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);
  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }
  // Tail: elements left over after the last whole pack.
  T* tail_dst = dst + pack_count * pack;
  const size_t tail_count = count - pack_count * pack;
  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}
// Extracts the fill value from a Scalar as type T.
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}

// half specialization: Scalar is read as float and then narrowed, since
// Scalar::Value is not instantiated for half here.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}

// bfloat16 support from the CUDA original is disabled in this ROCm port.
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
//   return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif  // CUDA_VERSION >= 11000
// Launches FillGpu with a compile-time pack width (selected when pack != 0).
template<typename T, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
                                                                T value, size_t count) {
  FillGpu<T, pack>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
}

// Overload selected when the alignment-derived pack computes to 0 (alignment
// smaller than sizeof(T)); should be unreachable for correctly typed pointers.
template<typename T, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
                                                                T value, size_t count) {
  LOG(FATAL) << "wrong alignment";
}
// Chooses the widest store width (16/8/4/2/1 bytes) compatible with dst's
// address alignment and dispatches to the matching LaunchPackFill.
template<typename T>
void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) {
  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
  if (uintptr % 16 == 0) {
    LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 8 == 0) {
    LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 4 == 0) {
    LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 2 == 0) {
    LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);
  } else {
    // 1 / sizeof(T) is 0 for any T wider than one byte, routing to the
    // LOG(FATAL) overload; for one-byte T it is a plain unpacked fill.
    LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);
  }
}
// Fill primitive: writes `count` copies of `value` (converted to T) into dst,
// asynchronously on the stream's HIP queue.
template<typename T>
class FillImpl : public Fill {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillImpl);
  FillImpl() = default;
  ~FillImpl() override = default;

  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
  }
};
// Type-erasing factory helper for FillImpl<T>.
template<typename T>
std::unique_ptr<Fill> NewFill() {
  return std::unique_ptr<Fill>(new FillImpl<T>());
}

// Maps a DataType to the matching FillImpl via a static lookup table.
class FillFactoryImpl : public FillFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
  FillFactoryImpl() = default;
  ~FillFactoryImpl() override = default;

  // Returns nullptr for data types not in CUDA_PRIMITIVE_ALL_TYPE_SEQ.
  std::unique_ptr<Fill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
    const auto it = new_fill_handle.find(data_type);
    if (it != new_fill_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// NOTE(review): this span is the second side of the GitLab diff for
// fill.hip.cpp and repeats the definitions above verbatim.

// Storage<size>: POD blob with size and alignment both equal to `size`.
template<size_t size>
using Storage = typename std::aligned_storage<size, size>::type;

// Pack: `pack` copies of T, viewable element-wise or as one aligned word.
template<typename T, size_t pack>
union Pack {
  static constexpr size_t size = sizeof(T) * pack;
  // Replicates `value` into every slot of the pack.
  explicit __device__ __host__ Pack(T value) {
    static_assert(sizeof(Pack) == size, "");
    static_assert(alignof(Pack) == size, "");
#pragma unroll
    for (size_t i = 0; i < pack; ++i) { elem[i] = value; }
  }
  T elem[pack];
  Storage<size> storage;
};

// Kernel: packed bulk fill followed by a scalar tail fill.
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
  const size_t pack_count = count / pack;
  Pack<T, pack> pack_value(value);
  auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);
  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }
  T* tail_dst = dst + pack_count * pack;
  const size_t tail_count = count - pack_count * pack;
  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}

// Extracts the fill value from a Scalar as type T.
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}

// half specialization: read as float, then narrow.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}

// bfloat16 support from the CUDA original is disabled in this ROCm port.
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
//   return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif  // CUDA_VERSION >= 11000

// Launches FillGpu with a compile-time pack width (pack != 0).
template<typename T, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
                                                                T value, size_t count) {
  FillGpu<T, pack>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
}

// Overload selected when the pack width computes to 0; should be unreachable.
template<typename T, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
                                                                T value, size_t count) {
  LOG(FATAL) << "wrong alignment";
}

// Chooses the widest store width compatible with dst's address alignment.
template<typename T>
void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) {
  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
  if (uintptr % 16 == 0) {
    LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 8 == 0) {
    LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 4 == 0) {
    LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);
  } else if (uintptr % 2 == 0) {
    LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);
  } else {
    // 1 / sizeof(T) is 0 for T wider than one byte -> LOG(FATAL) overload.
    LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);
  }
}

// Fill primitive: writes `count` copies of `value` (as T) into dst.
template<typename T>
class FillImpl : public Fill {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillImpl);
  FillImpl() = default;
  ~FillImpl() override = default;

  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
  }
};

// Type-erasing factory helper for FillImpl<T>.
template<typename T>
std::unique_ptr<Fill> NewFill() {
  return std::unique_ptr<Fill>(new FillImpl<T>());
}

// Maps a DataType to the matching FillImpl via a static lookup table.
class FillFactoryImpl : public FillFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
  FillFactoryImpl() = default;
  ~FillFactoryImpl() override = default;

  // Returns nullptr for unsupported data types.
  std::unique_ptr<Fill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
    const auto it = new_fill_handle.find(data_type);
    if (it != new_fill_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/memcpy.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
class
MemcpyImpl
:
public
Memcpy
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyImpl
);
MemcpyImpl
()
=
default
;
~
MemcpyImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
void
*
dst
,
const
void
*
src
,
size_t
count
)
override
{
if
(
dst
==
src
)
{
return
;
}
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
(
hipMemcpyAsync
(
dst
,
src
,
count
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
}
};
class
MemcpyFactoryImpl
:
public
MemcpyFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemcpyFactoryImpl
);
MemcpyFactoryImpl
()
=
default
;
~
MemcpyFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Memcpy
>
New
(
MemcpyKind
kind
)
override
{
return
std
::
unique_ptr
<
Memcpy
>
(
new
MemcpyImpl
());
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemcpyFactory
,
MemcpyFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// NOTE(review): second side of the GitLab diff for memcpy.cpp; repeats the
// definitions above verbatim.

// Device memcpy primitive backed by hipMemcpyAsync (direction inferred via
// hipMemcpyDefault); a self-copy is a no-op.
class MemcpyImpl : public Memcpy {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);
  MemcpyImpl() = default;
  ~MemcpyImpl() override = default;

  void Launch(Stream* stream, void* dst, const void* src, size_t count) override {
    if (dst == src) { return; }
    auto* cuda_stream = stream->As<CudaStream>();
    OF_CUDA_CHECK(hipMemcpyAsync(dst, src, count, hipMemcpyDefault, cuda_stream->cuda_stream()));
  }
};

// Factory for the HIP Memcpy primitive; `kind` is unused.
class MemcpyFactoryImpl : public MemcpyFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);
  MemcpyFactoryImpl() = default;
  ~MemcpyFactoryImpl() override = default;

  std::unique_ptr<Memcpy> New(MemcpyKind kind) override {
    return std::unique_ptr<Memcpy>(new MemcpyImpl());
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif
oneflow/core/ep/rocm/primitive/memset.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
class
MemsetImpl
:
public
Memset
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemsetImpl
);
MemsetImpl
()
=
default
;
~
MemsetImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
void
*
ptr
,
int
value
,
size_t
count
)
override
{
auto
*
cuda_stream
=
stream
->
As
<
CudaStream
>
();
OF_CUDA_CHECK
(
hipMemsetAsync
(
ptr
,
value
,
count
,
cuda_stream
->
cuda_stream
()));
}
};
class
MemsetFactoryImpl
:
public
MemsetFactory
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
MemsetFactoryImpl
);
MemsetFactoryImpl
()
=
default
;
~
MemsetFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
Memset
>
New
()
override
{
return
std
::
unique_ptr
<
Memset
>
(
new
MemsetImpl
());
}
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
MemsetFactory
,
MemsetFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
// NOTE(review): second side of the GitLab diff for memset.cpp; repeats the
// definitions above verbatim.

// Device memset primitive backed by hipMemsetAsync.
class MemsetImpl : public Memset {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);
  MemsetImpl() = default;
  ~MemsetImpl() override = default;

  void Launch(Stream* stream, void* ptr, int value, size_t count) override {
    auto* cuda_stream = stream->As<CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(ptr, value, count, cuda_stream->cuda_stream()));
  }
};

// Factory for the HIP Memset primitive.
class MemsetFactoryImpl : public MemsetFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);
  MemsetFactoryImpl() = default;
  ~MemsetFactoryImpl() override = default;

  std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#endif
oneflow/core/ep/rocm/primitive/permute.hip.cpp
View file @
8f7de847
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
permute
{
namespace
internal
{
namespace
{
// Shared-memory tile edge lengths for the transpose kernels: 32 for 4-byte
// elements, 64 for 2-byte elements.
constexpr int32_t kMov4TileSize = 32;
constexpr int32_t kMov2TileSize = 64;
// Rows of threads per block; each thread strides the tile by kBlockRows rows.
constexpr int32_t kBlockRows = 8;
// Generic permute kernel: for each destination offset, reconstructs the
// destination nd-index, scatters it through `permutation` into a source
// nd-index, and copies one movement_size-wide element.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
  // T is an opaque movement_size-byte unit; the kernel is dtype-agnostic.
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
    for (size_t dim = 0; dim < num_dims; ++dim) {
      // dst axis `dim` corresponds to src axis permutation[dim].
      src_index[params.permutation[dim]] = dst_index[dim];
    }
    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
    dst[i] = src[src_offset];
  }
}
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                     IndexType cols, IndexType num_tile_rows,
                                     IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  __shared__ T tile[tile_size][tile_size + 1];  // To avoid bank conflict.
  const T* src = reinterpret_cast<const T*>(src_ptr);
  T* dst = reinterpret_cast<T*>(dst_ptr);
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  // Grid-stride over all tiles of all batches.
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    const IndexType tile_index =
        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the
                                           // flatten index of tile in a batch.
    const IndexType tile_row_index = tile_index / num_tile_cols;  // the row index of tile in a batch.
    const IndexType tile_col_index =
        tile_index
        - tile_row_index * num_tile_cols;  // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType offset = batch_index * src_rows * src_cols;
    {
      // Phase 1: load a tile from src into shared memory.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
        }
      }
    }
    __syncthreads();
    {
      // Phase 2: write the transposed tile into dst.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
        }
      }
    }
    __syncthreads();  // Tile buffer is reused by the next iteration.
  }
}
/*
Here is a movement-size=2 version of Batch Transpose.
When the H and W can be divided by 2, we can read data with movement size 4 and
write back with movement size 4.
*/
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                              IndexType cols, IndexType num_tile_rows,
                                              IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  static_assert(tile_size % 2 == 0, "");
  using T_MOV2 = typename std::aligned_storage<2, 2>::type;
  using T_MOV4 = typename std::aligned_storage<4, 4>::type;
  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
  // Use union structure to process Load and Store.
  __shared__ union {
    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half [64][66]
    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]
  } tile_mem;
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    const IndexType tile_index =
        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the
                                           // flatten index of tile in a batch.
    const IndexType tile_row_index = tile_index / num_tile_cols;  // the row index of tile in a batch.
    const IndexType tile_col_index =
        tile_index
        - tile_row_index * num_tile_cols;  // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType offset = batch_index * src_rows * src_cols;
    {
      // Phase 1: load two 2-byte elements at a time through the 4-byte view.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile_mem.tile_m4[row_in_tile][col_in_tile] =
              src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
        }
      }
    }
    __syncthreads();
    {
      // Phase 2: gather two 2-byte elements from the transposed position and
      // store them as one 4-byte word.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        union {
          T_MOV4 m4;
          T_MOV2 m2[2];
        } tmp_storage;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
          tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
          dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
        }
      }
    }
    __syncthreads();
  }
}
// Computes the tile grid for a (num_batches, rows, cols) transpose and
// launches either the packed movement-size-2 kernel or the generic one.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
void LaunchBatchTransposeKernel(hipStream_t& cuda_stream,
                                const PermuteKernelParams<num_dims, IndexType>& params,
                                const IndexType& num_batches, const IndexType& rows,
                                const IndexType& cols) {
  IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;
  IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;
  const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
  // Cap launched blocks; the kernels stride over the remaining tiles.
  int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
  if (tile_size == kMov2TileSize) {
    const int32_t half2_thread = tile_size / 2;  // cause each thread process two half elements.
    BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>
        <<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
            block_nums);  // Set threads num as 32x8 cause each threads
                          // process 4 elements to 64x66 half share memory.
  } else {
    BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>
        <<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
  }
}
// True iff both matrix dimensions are at least tile_size, i.e. the tiled
// transpose path has at least one full tile to work with.
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
  return !(rows < tile_size || cols < tile_size);
}
// Decides whether the permutation is a (batched) 2-D transpose that the tiled
// kernels can handle: both dims must reach tile_size and the permutation must
// swap the two innermost axes.
template<size_t num_dims, size_t tile_size, typename IndexType>
bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
                               const IndexType& rows, const IndexType& cols) {
  if (!CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) { return false; }
  // 2d tensor case: (0, 1) -> (1, 0)
  if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) { return true; }
  // 3d tensor case: (0, 1, 2) -> (0, 2, 1)
  if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) { return true; }
  return false;
}
// True iff the half2 fast path applies: 2-byte elements, even rows/cols, and
// both pointers aligned to 4 bytes so paired elements can move as one word.
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
  const auto src_addr = reinterpret_cast<std::uintptr_t>(src);
  const auto dst_addr = reinterpret_cast<std::uintptr_t>(dst);
  if (movement_size != 2) { return false; }
  if (rows % 2 != 0 || cols % 2 != 0) { return false; }
  return (src_addr % 4 == 0) && (dst_addr % 4 == 0);
}
// Splits src_dims into (num_batches, rows, cols).  A 2-d shape has no leading
// batch dimension and is treated as a single batch.
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
                              IndexType* cols) {
  const size_t base = (num_dims == 2) ? 0 : 1;
  *num_batches = (base == 0) ? static_cast<IndexType>(1) : static_cast<IndexType>(src_dims[0]);
  *rows = static_cast<IndexType>(src_dims[base]);
  *cols = static_cast<IndexType>(src_dims[base + 1]);
}
// Entry point for one (num_dims, movement_size, IndexType) instantiation:
// uses the tiled batch-transpose fast path when the permutation is a
// (batched) 2-D transpose, otherwise falls back to the generic PermuteKernel.
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,
                  void* dst, size_t count) {
  PermuteKernelParams<num_dims, IndexType> params =
      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
  hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
  if (num_dims == 2 || num_dims == 3) {
    IndexType num_batches;
    IndexType rows;
    IndexType cols;
    InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
    if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,
                                                           cols)) {
      if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {
        // Even rows/cols and 4-byte-aligned pointers: use the half2 tile kernel.
        LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
                                                                          num_batches, rows, cols);
      } else {
        LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
            cuda_stream, params, num_batches, rows, cols);
      }
    } else {
      // Permutation is not a plain transpose: generic element-wise kernel.
      PermuteKernel<num_dims, movement_size, IndexType>
          <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
    }
  } else {
    PermuteKernel<num_dims, movement_size, IndexType>
        <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
  }
}
// Permute primitive implementation.  SimplifyThenLaunch (declared in the
// shared permute_impl.h header) performs the dispatch to LaunchKernel.
class PermuteImpl : public Permute {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
  PermuteImpl() = default;
  ~PermuteImpl() override = default;
  using Permute::Launch;

  // Permutes src into dst according to `permutation` on `stream`.
  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
              const void* src, const int* permutation, void* dst) override {
    SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);
  }
};
// Factory for the Permute primitive.
class PermuteFactoryImpl : public PermuteFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);
  PermuteFactoryImpl() = default;
  ~PermuteFactoryImpl() override = default;

  // Returns a Permute primitive, or nullptr when the rank is unsupported.
  std::unique_ptr<Permute> New(size_t max_num_dims) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }
    return std::unique_ptr<Permute>(new PermuteImpl());
  }
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
PermuteFactory
,
PermuteFactoryImpl
);
}
// namespace
}
// namespace internal
}
// namespace permute
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
permute
{
namespace
internal
{
namespace
{
// Tile edge length for the generic (movement-size-4) transpose path.
constexpr int32_t kMov4TileSize = 32;
// Tile edge length for the packed half2 (movement-size-2) transpose path.
constexpr int32_t kMov2TileSize = 64;
// blockDim.y of the transpose kernels; each thread strides over the tile rows
// in steps of kBlockRows.
constexpr int32_t kBlockRows = 8;
// Generic element-wise permute: each thread maps a destination offset back to
// the corresponding source offset via the permutation and copies one element.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
  // T is an opaque POD of exactly movement_size bytes, so one assignment
  // moves movement_size bytes regardless of the logical element type.
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
    for (size_t dim = 0; dim < num_dims; ++dim) {
      // permutation[dim] is the source axis producing destination axis `dim`.
      src_index[params.permutation[dim]] = dst_index[dim];
    }
    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
    dst[i] = src[src_offset];
  }
}
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// Each block stages a tile_size x tile_size tile through shared memory so that
// consecutive threadIdx.x access consecutive global addresses on both the
// load and the store side.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                     IndexType cols, IndexType num_tile_rows,
                                     IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  __shared__ T tile[tile_size][tile_size + 1];  // To avoid bank conflict.
  const T* src = reinterpret_cast<const T*>(src_ptr);
  T* dst = reinterpret_cast<T*>(dst_ptr);
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  // Stride over all tiles so a capped grid still covers every batch.
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    // equal to i % (num_tile_rows*num_tile_cols). the
    // flatten index of tile in a batch.
    const IndexType tile_index = i - batch_index * batch_num_tile;
    // the row index of tile in a batch.
    const IndexType tile_row_index = tile_index / num_tile_cols;
    // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType tile_col_index = tile_index - tile_row_index * num_tile_cols;
    const IndexType offset = batch_index * src_rows * src_cols;
    {
      // Load phase: copy the source tile row-major into shared memory,
      // skipping out-of-range elements of edge tiles.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
        }
      }
    }
    __syncthreads();
    {
      // Store phase: write the tile back transposed (shared memory is read
      // with swapped indices), again bounds-checked for edge tiles.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
        }
      }
    }
    __syncthreads();
  }
}
/*
Here is a Movementsie=2 version of Batch Transpose.
When the H W can be divided by 2. we can read data use movementsize=4, and write back as
movementsize=4.
*/
// Each thread loads one 4-byte pack (two 2-byte elements); on the store side
// two adjacent 2-byte elements are regathered through the shared-memory union
// and written back as one 4-byte pack.
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                              IndexType cols, IndexType num_tile_rows,
                                              IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  static_assert(tile_size % 2 == 0, "");
  using T_MOV2 = typename std::aligned_storage<2, 2>::type;
  using T_MOV4 = typename std::aligned_storage<4, 4>::type;
  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
  // Use union structure to process Load and Store.
  __shared__ union {
    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half [64][66]
    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]
  } tile_mem;
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  // Stride over all tiles so a capped grid still covers every batch.
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    // equal to i % (num_tile_rows*num_tile_cols). the
    // flatten index of tile in a batch.
    const IndexType tile_index = i - batch_index * batch_num_tile;
    // the row index of tile in a batch.
    const IndexType tile_row_index = tile_index / num_tile_cols;
    // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType tile_col_index = tile_index - tile_row_index * num_tile_cols;
    const IndexType offset = batch_index * src_rows * src_cols;
    {
      // Load phase: col_in_matrix is the first of the two elements packed in
      // the 4-byte word each thread reads; the element offset is halved to
      // index the T_MOV4 view of src.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile_mem.tile_m4[row_in_tile][col_in_tile] =
              src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
        }
      }
    }
    __syncthreads();
    {
      // Store phase: pick two adjacent 2-byte elements from the transposed
      // position via the tile_m2 view and emit them as one 4-byte word.
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        union {
          T_MOV4 m4;
          T_MOV2 m2[2];
        } tmp_storage;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
          tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
          dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
        }
      }
    }
    __syncthreads();
  }
}
// Selects grid/block shape and launches the matching batched-transpose kernel.
// NOTE(review): block_nums is int32_t; assumes num_batches * tiles fits in 32 bits.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
void LaunchBatchTransposeKernel(hipStream_t& cuda_stream,
                                const PermuteKernelParams<num_dims, IndexType>& params,
                                const IndexType& num_batches, const IndexType& rows,
                                const IndexType& cols) {
  IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;  // ceil(rows / tile_size)
  IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;  // ceil(cols / tile_size)
  const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
  // Cap the launched grid; the kernels stride over all block_nums tiles.
  int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
  if (tile_size == kMov2TileSize) {
    const int32_t half2_thread = tile_size / 2;  // cause each thread process two half elements.
    BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>
        <<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
            block_nums);  // Set threads num as 32x8 cause each threads
                          // process 4 elements to 64x66 half share memory.
  } else {
    BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>
        <<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
  }
}
// True iff both matrix dimensions are at least tile_size, i.e. the tiled
// transpose path has at least one full tile to work with.
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
  return !(rows < tile_size || cols < tile_size);
}
// Decides whether the permutation is a (batched) 2-D transpose that the tiled
// kernels can handle: both dims must reach tile_size and the permutation must
// swap the two innermost axes.
template<size_t num_dims, size_t tile_size, typename IndexType>
bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
                               const IndexType& rows, const IndexType& cols) {
  if (!CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) { return false; }
  // 2d tensor case: (0, 1) -> (1, 0)
  if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) { return true; }
  // 3d tensor case: (0, 1, 2) -> (0, 2, 1)
  if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) { return true; }
  return false;
}
// True iff the half2 fast path applies: 2-byte elements, even rows/cols, and
// both pointers aligned to 4 bytes so paired elements can move as one word.
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
  const auto src_addr = reinterpret_cast<std::uintptr_t>(src);
  const auto dst_addr = reinterpret_cast<std::uintptr_t>(dst);
  if (movement_size != 2) { return false; }
  if (rows % 2 != 0 || cols % 2 != 0) { return false; }
  return (src_addr % 4 == 0) && (dst_addr % 4 == 0);
}
// Splits src_dims into (num_batches, rows, cols).  A 2-d shape has no leading
// batch dimension and is treated as a single batch.
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
                              IndexType* cols) {
  const size_t base = (num_dims == 2) ? 0 : 1;
  *num_batches = (base == 0) ? static_cast<IndexType>(1) : static_cast<IndexType>(src_dims[0]);
  *rows = static_cast<IndexType>(src_dims[base]);
  *cols = static_cast<IndexType>(src_dims[base + 1]);
}
// Entry point for one (num_dims, movement_size, IndexType) instantiation:
// uses the tiled batch-transpose fast path when the permutation is a
// (batched) 2-D transpose, otherwise falls back to the generic PermuteKernel.
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,
                  void* dst, size_t count) {
  PermuteKernelParams<num_dims, IndexType> params =
      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
  hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
  if (num_dims == 2 || num_dims == 3) {
    IndexType num_batches;
    IndexType rows;
    IndexType cols;
    InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
    if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,
                                                           cols)) {
      if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {
        // Even rows/cols and 4-byte-aligned pointers: use the half2 tile kernel.
        LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
                                                                          num_batches, rows, cols);
      } else {
        LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
            cuda_stream, params, num_batches, rows, cols);
      }
    } else {
      // Permutation is not a plain transpose: generic element-wise kernel.
      PermuteKernel<num_dims, movement_size, IndexType>
          <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
    }
  } else {
    PermuteKernel<num_dims, movement_size, IndexType>
        <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
  }
}
// Permute primitive implementation.  SimplifyThenLaunch (declared in the
// shared permute_impl.h header) performs the dispatch to LaunchKernel.
class PermuteImpl : public Permute {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
  PermuteImpl() = default;
  ~PermuteImpl() override = default;
  using Permute::Launch;

  // Permutes src into dst according to `permutation` on `stream`.
  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
              const void* src, const int* permutation, void* dst) override {
    SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);
  }
};
// Factory for the Permute primitive.
class PermuteFactoryImpl : public PermuteFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);
  PermuteFactoryImpl() = default;
  ~PermuteFactoryImpl() override = default;

  // Returns a Permute primitive, or nullptr when the rank is unsupported.
  std::unique_ptr<Permute> New(size_t max_num_dims) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }
    return std::unique_ptr<Permute>(new PermuteImpl());
  }
};
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
PermuteFactory
,
PermuteFactoryImpl
);
}
// namespace
}
// namespace internal
}
// namespace permute
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/softmax.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
enum
class
Algorithm
{
kSoftmax
,
kLogSoftmax
,
};
// Computes y = (log-)softmax(x) row-wise over a rows x cols matrix on the
// given HIP stream.
template<Algorithm algorithm, typename T>
void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
  // ComputeType is the accumulation type DefaultComputeType selects for T.
  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
  oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
  oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
  if (algorithm == Algorithm::kSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
        cuda_stream, load, store, rows, cols)));
  } else if (algorithm == Algorithm::kLogSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
        cuda_stream, load, store, rows, cols)));
  } else {
    // Unreachable for the two enumerators above.
    UNIMPLEMENTED();
  }
}
// Softmax/LogSoftmax primitive for element type T, exposed through SoftmaxBase.
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
  SoftmaxImpl() = default;
  ~SoftmaxImpl() override = default;

  // Computes y = (log-)softmax(x) row-wise on the stream's HIP queue.
  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
                             reinterpret_cast<T*>(y));
  }
};
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBase
>
NewSoftmax
()
{
return
std
::
unique_ptr
<
SoftmaxBase
>
(
new
SoftmaxImpl
<
SoftmaxBase
,
algorithm
,
T
>
());
}
// Factory keyed on DataType: returns the matching SoftmaxImpl, or nullptr for
// unsupported data types.
template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>
class GenericSoftmaxFactoryImpl : public FactoryBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);
  GenericSoftmaxFactoryImpl() = default;
  ~GenericSoftmaxFactoryImpl() override = default;

  std::unique_ptr<SoftmaxBase> New(DataType data_type) override {
// One map entry per floating type in CUDA_PRIMITIVE_FLOATING_TYPE_SEQ.
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
  {type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>
        new_softmax_handle{
            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
    const auto it = new_softmax_handle.find(data_type);
    if (it != new_softmax_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
using
SoftmaxFactoryImpl
=
GenericSoftmaxFactoryImpl
<
SoftmaxFactory
,
Softmax
,
Algorithm
::
kSoftmax
>
;
using
LogSoftmaxFactoryImpl
=
GenericSoftmaxFactoryImpl
<
LogSoftmaxFactory
,
LogSoftmax
,
Algorithm
::
kLogSoftmax
>
;
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
SoftmaxFactory
,
SoftmaxFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
LogSoftmaxFactory
,
LogSoftmaxFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
enum
class
Algorithm
{
kSoftmax
,
kLogSoftmax
,
};
// Computes y = (log-)softmax(x) row-wise over a rows x cols matrix on the
// given HIP stream.
template<Algorithm algorithm, typename T>
void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
  // ComputeType is the accumulation type DefaultComputeType selects for T.
  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
  oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
  oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
  if (algorithm == Algorithm::kSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
        cuda_stream, load, store, rows, cols)));
  } else if (algorithm == Algorithm::kLogSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
        cuda_stream, load, store, rows, cols)));
  } else {
    // Unreachable for the two enumerators above.
    UNIMPLEMENTED();
  }
}
// Softmax/LogSoftmax primitive for element type T, exposed through SoftmaxBase.
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
  SoftmaxImpl() = default;
  ~SoftmaxImpl() override = default;

  // Computes y = (log-)softmax(x) row-wise on the stream's HIP queue.
  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
                             reinterpret_cast<T*>(y));
  }
};
template
<
typename
SoftmaxBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBase
>
NewSoftmax
()
{
return
std
::
unique_ptr
<
SoftmaxBase
>
(
new
SoftmaxImpl
<
SoftmaxBase
,
algorithm
,
T
>
());
}
// Factory keyed on DataType: returns the matching SoftmaxImpl, or nullptr for
// unsupported data types.
template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>
class GenericSoftmaxFactoryImpl : public FactoryBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);
  GenericSoftmaxFactoryImpl() = default;
  ~GenericSoftmaxFactoryImpl() override = default;

  std::unique_ptr<SoftmaxBase> New(DataType data_type) override {
// One map entry per floating type in CUDA_PRIMITIVE_FLOATING_TYPE_SEQ.
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
  {type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>
        new_softmax_handle{
            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
    const auto it = new_softmax_handle.find(data_type);
    if (it != new_softmax_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
using
SoftmaxFactoryImpl
=
GenericSoftmaxFactoryImpl
<
SoftmaxFactory
,
Softmax
,
Algorithm
::
kSoftmax
>
;
using
LogSoftmaxFactoryImpl
=
GenericSoftmaxFactoryImpl
<
LogSoftmaxFactory
,
LogSoftmax
,
Algorithm
::
kLogSoftmax
>
;
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
SoftmaxFactory
,
SoftmaxFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
LogSoftmaxFactory
,
LogSoftmaxFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/softmax_backward.hip.cpp
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
enum
class
Algorithm
{
kSoftmax
,
kLogSoftmax
,
};
// Computes dx from forward output y and upstream gradient dy, row-wise over a
// rows x cols matrix, on the given HIP stream.
template<Algorithm algorithm, typename T>
void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,
                        T* dx) {
  // ComputeType is the accumulation type DefaultComputeType selects for T.
  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
  cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);
  if (algorithm == Algorithm::kSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
                                                      decltype(store), ComputeType>(
        cuda_stream, load_y, load_dy, store, rows, cols)));
  } else if (algorithm == Algorithm::kLogSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),
                                                         decltype(store), ComputeType>(
        cuda_stream, load_y, load_dy, store, rows, cols)));
  } else {
    // Unreachable for the two enumerators above.
    UNIMPLEMENTED();
  }
}
// (Log-)softmax backward primitive for element type T, exposed through
// SoftmaxBackwardBase.
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
class SoftmaxBackwardImpl : public SoftmaxBackwardBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);
  SoftmaxBackwardImpl() = default;
  ~SoftmaxBackwardImpl() override = default;

  // Computes dx from (y, dy) row-wise on the stream's HIP queue.
  void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,
              void* dx) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxBackwardGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(y),
                                     reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));
  }
};
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBackwardBase
>
NewSoftmaxBackward
()
{
return
std
::
unique_ptr
<
SoftmaxBackwardBase
>
(
new
SoftmaxBackwardImpl
<
SoftmaxBackwardBase
,
algorithm
,
T
>
());
}
// Factory keyed on DataType: returns the matching SoftmaxBackwardImpl, or
// nullptr for unsupported data types.
template<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm>
class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);
  GenericSoftmaxBackwardFactoryImpl() = default;
  ~GenericSoftmaxBackwardFactoryImpl() override = default;

  std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {
// One map entry per floating type in CUDA_PRIMITIVE_FLOATING_TYPE_SEQ.
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
  {type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>
        new_softmax_backward_handle{
            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
    const auto it = new_softmax_backward_handle.find(data_type);
    if (it != new_softmax_backward_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
using
SoftmaxBackwardFactoryImpl
=
GenericSoftmaxBackwardFactoryImpl
<
SoftmaxBackwardFactory
,
SoftmaxBackward
,
Algorithm
::
kSoftmax
>
;
using
LogSoftmaxBackwardFactoryImpl
=
GenericSoftmaxBackwardFactoryImpl
<
LogSoftmaxBackwardFactory
,
LogSoftmaxBackward
,
Algorithm
::
kLogSoftmax
>
;
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
SoftmaxBackwardFactory
,
SoftmaxBackwardFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
LogSoftmaxBackwardFactory
,
LogSoftmaxBackwardFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
{
enum
class
Algorithm
{
kSoftmax
,
kLogSoftmax
,
};
// Computes dx from forward output y and upstream gradient dy, row-wise over a
// rows x cols matrix, on the given HIP stream.
template<Algorithm algorithm, typename T>
void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,
                        T* dx) {
  // ComputeType is the accumulation type DefaultComputeType selects for T.
  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
  cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);
  if (algorithm == Algorithm::kSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
                                                      decltype(store), ComputeType>(
        cuda_stream, load_y, load_dy, store, rows, cols)));
  } else if (algorithm == Algorithm::kLogSoftmax) {
    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),
                                                         decltype(store), ComputeType>(
        cuda_stream, load_y, load_dy, store, rows, cols)));
  } else {
    // Unreachable for the two enumerators above.
    UNIMPLEMENTED();
  }
}
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
class
SoftmaxBackwardImpl
:
public
SoftmaxBackwardBase
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
SoftmaxBackwardImpl
);
SoftmaxBackwardImpl
()
=
default
;
~
SoftmaxBackwardImpl
()
override
=
default
;
void
Launch
(
Stream
*
stream
,
size_t
rows
,
size_t
cols
,
const
void
*
y
,
const
void
*
dy
,
void
*
dx
)
override
{
hipStream_t
cuda_stream
=
stream
->
As
<
CudaStream
>
()
->
cuda_stream
();
SoftmaxBackwardGpu
<
algorithm
,
T
>
(
cuda_stream
,
rows
,
cols
,
reinterpret_cast
<
const
T
*>
(
y
),
reinterpret_cast
<
const
T
*>
(
dy
),
reinterpret_cast
<
T
*>
(
dx
));
}
};
template
<
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
,
typename
T
>
std
::
unique_ptr
<
SoftmaxBackwardBase
>
NewSoftmaxBackward
()
{
return
std
::
unique_ptr
<
SoftmaxBackwardBase
>
(
new
SoftmaxBackwardImpl
<
SoftmaxBackwardBase
,
algorithm
,
T
>
());
}
template
<
typename
BackwardFactoryBase
,
typename
SoftmaxBackwardBase
,
Algorithm
algorithm
>
class
GenericSoftmaxBackwardFactoryImpl
:
public
BackwardFactoryBase
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
GenericSoftmaxBackwardFactoryImpl
);
GenericSoftmaxBackwardFactoryImpl
()
=
default
;
~
GenericSoftmaxBackwardFactoryImpl
()
override
=
default
;
std
::
unique_ptr
<
SoftmaxBackwardBase
>
New
(
DataType
data_type
)
override
{
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
static
const
std
::
map
<
DataType
,
std
::
function
<
std
::
unique_ptr
<
SoftmaxBackwardBase
>
()
>>
new_softmax_backward_handle
{
OF_PP_FOR_EACH_TUPLE
(
MAKE_NEW_SOFTMAX_ENTRY
,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ
)};
#undef MAKE_NEW_SOFTMAX_ENTRY
const
auto
it
=
new_softmax_backward_handle
.
find
(
data_type
);
if
(
it
!=
new_softmax_backward_handle
.
end
())
{
return
it
->
second
();
}
else
{
return
nullptr
;
}
}
};
using
SoftmaxBackwardFactoryImpl
=
GenericSoftmaxBackwardFactoryImpl
<
SoftmaxBackwardFactory
,
SoftmaxBackward
,
Algorithm
::
kSoftmax
>
;
using
LogSoftmaxBackwardFactoryImpl
=
GenericSoftmaxBackwardFactoryImpl
<
LogSoftmaxBackwardFactory
,
LogSoftmaxBackward
,
Algorithm
::
kLogSoftmax
>
;
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
SoftmaxBackwardFactory
,
SoftmaxBackwardFactoryImpl
);
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
LogSoftmaxBackwardFactory
,
LogSoftmaxBackwardFactoryImpl
);
}
// namespace
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
oneflow/core/ep/rocm/primitive/type_seq.h
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_

#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// bfloat16 is not enabled in this ROCm port; the CUDA >= 11000 code is kept
// commented out for reference.
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif  // CUDA_VERSION >= 11000

// (c++ type, DataType enum) tuple sequences used with OF_PP_FOR_EACH_TUPLE to
// instantiate primitives for each supported dtype. The CUDA_ prefix is kept so
// that shared primitive sources compile unchanged on ROCm.
#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)

// bfloat16 intentionally expands to nothing on ROCm (see note above).
// #if CUDA_VERSION >= 11000
// #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// #endif  // CUDA_VERSION >= 11000

#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
  CUDA_PRIMITIVE_BOOL_TYPE_SEQ      \
  CUDA_PRIMITIVE_CHAR_TYPE_SEQ      \
  CUDA_PRIMITIVE_INT8_TYPE_SEQ      \
  CUDA_PRIMITIVE_UINT8_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT32_TYPE_SEQ     \
  CUDA_PRIMITIVE_INT64_TYPE_SEQ     \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ     \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ    \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ   \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ          \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ         \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ        \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#define UTIL_OPS_DATA_TYPE_SEQ    \
  CUDA_PRIMITIVE_INT8_TYPE_SEQ    \
  CUDA_PRIMITIVE_UINT8_TYPE_SEQ   \
  CUDA_PRIMITIVE_INT32_TYPE_SEQ   \
  CUDA_PRIMITIVE_INT64_TYPE_SEQ   \
  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ   \
  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ  \
  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#endif  // WITH_ROCM

#endif  // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
\ No newline at end of file
oneflow/core/ep/rocm/primitive/unary_functor.hip.h
View file @
8f7de847
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"

namespace oneflow {
namespace ep {
namespace primitive {

// Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in the source type.
// attr0/attr1 are unused for this op but required by the UnaryFunctor contract.
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src src) const {
    return static_cast<Src>(0.5) * src
           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src));
  }
};

// Tanh: use the precision-matched libm entry point per type.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, float, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, double, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); }
};

// half has no native tanh here; round-trip through float.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }
};

// IsInf / IsNan classification predicates; half is widened to float first.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }
};

// Defines a half specialization that delegates to the float functor,
// converting on the way in and out. Used for ops with no native half path.
#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op)                            \
  template<>                                                                    \
  struct UnaryFunctor<DeviceType::kCUDA, op, half, half> {                      \
    UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}   \
                                                                                \
    UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;            \
    OF_DEVICE_FUNC half operator()(half src) const {                            \
      return __float2half(float_functor(__half2float(src)));                    \
    }                                                                           \
  };

SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);

// bfloat16 specializations are disabled in this ROCm port; the CUDA >= 11000
// versions are kept commented out for reference.
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op)                       \
//   template<>                                                                   \
//   struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> {       \
//     UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}  \
//                                                                                \
//     UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;           \
//     OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const {             \
//       return __float2bfloat16(float_functor(__bfloat162float(src)));           \
//     }                                                                          \
//   };
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
//   UnaryFunctor(Scalar attr0, Scalar attr1) {}
//   OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// };
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
//   UnaryFunctor(Scalar attr0, Scalar attr1) {}
//   OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// };
// #endif

}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment