Commit 8f7de847 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

dtk

parent f262efc9
Pipeline #248 failed with stages
in 0 seconds
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h" #include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
template<typename Src, typename Dst> template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> { struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); } OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
}; };
template<> template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> { struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const { OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1))); return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
} }
}; };
template<> template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> { struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC half operator()(half src0, half src1) const { OF_DEVICE_FUNC half operator()(half src0, half src1) const {
return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1))); return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
} }
}; };
template<typename Src, typename Dst> template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> { struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) { OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0))); coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#elif defined(__HIP_DEVICE_COMPILE__) #elif defined(__HIP_DEVICE_COMPILE__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0))); coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else #else
coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0))); coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif #endif
} }
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Src>(0.5) return static_cast<Src>(0.5)
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x) * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
+ x * coef * exp(static_cast<Src>(-0.5) * x * x)) + x * coef * exp(static_cast<Src>(-0.5) * x * x))
* dy; * dy;
} }
Src coef; Src coef;
}; };
template<typename Src, typename Dst> template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> { struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
Src tanh_val = tanh(x); Src tanh_val = tanh(x);
return static_cast<Dst>(dy * (static_cast<Src>(1.0) - tanh_val * tanh_val)); return static_cast<Dst>(dy * (static_cast<Src>(1.0) - tanh_val * tanh_val));
} }
}; };
// /*********nv_bfloat16_kernel*******/ // /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// template<> // template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> { // struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {} // OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { // OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1))); // return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// } // }
// }; // };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \ // #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \ // template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \ // struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ // OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \ // \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \ // BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \ // OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \ // return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \ // } \
// }; // };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX); // SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \ #define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \
template<> \ template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, half, half> { \ struct BinaryFunctor<DeviceType::kCUDA, op, half, half> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\ \
BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \ BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src0, half src1) const { \ OF_DEVICE_FUNC half operator()(half src0, half src1) const { \
return __float2half(float_functor(__half2float(src0), __half2float(src1))); \ return __float2half(float_functor(__half2float(src0), __half2float(src1))); \
} \ } \
}; };
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX); SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h" #include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h" #include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
template<BinaryOp binary_op, typename Src, typename Dst> template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0, std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1); Scalar attr1);
namespace { namespace {
// Factory that resolves (binary op, src dtype, dst dtype) to a concrete
// broadcast-elementwise-binary primitive via a static lookup table built with
// preprocessor sequence expansion.
class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
  BroadcastElementwiseBinaryFactoryImpl() = default;
  ~BroadcastElementwiseBinaryFactoryImpl() override = default;

  // Attr-less overload: forwards with two default-constructed Scalar attrs.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
                                                  size_t max_num_dims) override {
    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
  }

  // Single-attr overload: forwards with a default-constructed second attr.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
                                                  size_t max_num_dims, Scalar attr0) override {
    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
  }

  // Full overload: looks up the (op, src, dst) tuple in a static table and
  // invokes the matching NewBroadcastElementwiseBinary<...> constructor.
  // Returns nullptr when the rank exceeds kMaxNumDims or no entry matches.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,
                                                  DataType dst_type, size_t max_num_dims,
                                                  Scalar attr0, Scalar attr1) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }
// Math ops: src and dst share one dtype pair.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                    \
                   OF_PP_PAIR_SECOND(data_type_pair)),                              \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),       \
                                 OF_PP_PAIR_FIRST(data_type_pair)>},
// Comparison/logical ops: arbitrary src dtype, bool dst dtype.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY(    \
    binary_op, src_data_type_pair, dst_data_type_pair)                          \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair),            \
                   OF_PP_PAIR_SECOND(dst_data_type_pair)),                      \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \
                                 OF_PP_PAIR_FIRST(dst_data_type_pair)>},
// Activation backward ops: floating dtypes only, src == dst.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                               \
                   OF_PP_PAIR_SECOND(data_type_pair)),                                         \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),                  \
                                 OF_PP_PAIR_FIRST(data_type_pair)>},
    // Built once on first call; maps the tuple key to a creator taking the
    // two scalar attrs.
    static const std::map<
        std::tuple<BinaryOp, DataType, DataType>,
        std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>
        new_broadcast_elementwise_binary_handle{
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                             BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
                BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
                BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
// NOTE(review): MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY is
// defined above but never #undef'd here — confirm whether that is intentional.
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
    const auto it = new_broadcast_elementwise_binary_handle.find(
        std::make_tuple(binary_op, src_type, dst_type));
    if (it != new_broadcast_elementwise_binary_handle.end()) {
      return it->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};
// Registers this factory for DeviceType::kCUDA (the ROCm build reuses the
// kCUDA device tag, per the rocm/ includes above).
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,
                           BroadcastElementwiseBinaryFactoryImpl);
} // namespace } // namespace
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h" #include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h" #include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
namespace { namespace {
// Maps (T, N) to a trivially-copyable storage type with the size and
// alignment of N packed Ts, enabling single-instruction vectorized
// loads/stores of N elements at once.
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};

// Convenience alias for GetPackType<T, N>::type.
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
// Union view of a pack of N Ts: `storage` is read/written as one aligned
// block, `elem` gives per-element access to the same bytes.
template<typename T, int N>
union Pack {
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
  // Intentionally empty: leave the union uninitialized; callers always
  // assign `storage` before reading `elem`.
  OF_DEVICE_FUNC Pack() {
    // do nothing
  }
  PackType<T, N> storage;
  T elem[N];
};
// Kernel parameter block, passed by value to the GPU kernel.
template<size_t max_dims, typename IndexType>
struct BroadcastElementwiseBinaryParams {
  NdIndexOffsetHelper<IndexType, max_dims> src0_index_helper;  // nd-index <-> offset for src0
  NdIndexOffsetHelper<IndexType, max_dims> src1_index_helper;  // nd-index <-> offset for src1
  NdIndexOffsetHelper<IndexType, max_dims> dst_index_helper;   // nd-index <-> offset for dst
  size_t num_dims;  // actual rank in use (<= max_dims)
  // Per-dim masks: 0 where the source dim is broadcast (extent 1), else 1.
  IndexType src0_index_mask[max_dims];
  IndexType src1_index_mask[max_dims];
  IndexType count{};       // number of dst pack elements to produce
  const void* src0{};      // type-erased input buffers
  const void* src1{};
  void* dst{};             // type-erased output buffer
  Scalar attr0;            // op-specific scalar attributes
  Scalar attr1;
};
// Broadcast binary kernel: each loop iteration produces one dst pack.
// A source with pack size 1 is broadcast across the dst pack's lanes;
// otherwise its pack size must equal the dst pack size.
template<BinaryOp binary_op, typename Src, typename Dst, size_t max_dims, size_t src0_pack_size,
         size_t src1_pack_size, typename IndexType>
__global__ void BroadcastElementwiseBinaryGpu(
    BroadcastElementwiseBinaryParams<max_dims, IndexType> params) {
  // dst pack size is the larger of the two source pack sizes.
  constexpr size_t dst_pack_size =
      src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size;
  static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, "");
  static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, "");
  const PackType<Src, src0_pack_size>* src0 =
      reinterpret_cast<const PackType<Src, src0_pack_size>*>(params.src0);
  const PackType<Src, src1_pack_size>* src1 =
      reinterpret_cast<const PackType<Src, src1_pack_size>*>(params.src1);
  PackType<Dst, dst_pack_size>* dst = reinterpret_cast<PackType<Dst, dst_pack_size>*>(params.dst);
  IndexType src0_index[max_dims];
  IndexType src1_index[max_dims];
  IndexType dst_index[max_dims];
  size_t num_dims = params.num_dims;
  CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {
    params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
// Mask maps a dst index to the matching source index: broadcast dims
// (mask 0) always read index 0.
#pragma unroll
    for (int i = 0; i < max_dims; ++i) {
      if (i < num_dims) {
        src0_index[i] = params.src0_index_mask[i] * dst_index[i];
        src1_index[i] = params.src1_index_mask[i] * dst_index[i];
      } else {
        src0_index[i] = 0;
        src1_index[i] = 0;
      }
    }
    const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims);
    const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims);
    // Load one pack per source, apply the functor lane by lane.
    Pack<Src, src0_pack_size> src0_pack;
    src0_pack.storage = src0[src0_offset];
    Pack<Src, src1_pack_size> src1_pack;
    src1_pack.storage = src1[src1_offset];
    Pack<Dst, dst_pack_size> dst_pack;
    BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor(params.attr0, params.attr1);
#pragma unroll
    for (int j = 0; j < dst_pack_size; ++j) {
      // Pack-size-1 sources broadcast elem[0] across all dst lanes.
      const Src src0_val =
          (src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0];
      const Src src1_val =
          (src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0];
      dst_pack.elem[j] = functor(src0_val, src1_val);
    }
    dst[offset] = dst_pack.storage;
  }
}
// Fills the parameter block and launches BroadcastElementwiseBinaryGpu on the
// stream's device queue. `count` is the number of dst packs (not elements).
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
         size_t src1_pack_size, typename IndexType>
void LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0,
                  const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst,
                  size_t count, Scalar attr0, Scalar attr1) {
  BroadcastElementwiseBinaryParams<max_dims, IndexType> params;
  for (size_t i = 0; i < num_dims; ++i) {
    // Extent-1 dims are broadcast: mask 0 pins the source index to 0 there.
    params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1;
    params.src1_index_mask[i] = (src1_dims[i] == 1) ? 0 : 1;
  }
  params.src0_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src0_dims, num_dims);
  params.src1_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src1_dims, num_dims);
  params.dst_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(dst_dims, num_dims);
  params.num_dims = num_dims;
  params.src0 = src0;
  params.src1 = src1;
  params.dst = dst;
  params.count = static_cast<IndexType>(count);
  params.attr0 = attr0;
  params.attr1 = attr1;
  auto* cuda_stream = stream->As<CudaStream>();
  BroadcastElementwiseBinaryGpu<op, T, R, max_dims, src0_pack_size, src1_pack_size, IndexType>
      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0,
         cuda_stream->cuda_stream()>>>(params);
}
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size, template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size> size_t src1_pack_size>
void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0, void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
void* dst, Scalar attr0, Scalar attr1) { void* dst, Scalar attr0, Scalar attr1) {
size_t count = GetElementCount(num_dims, dst_dims); size_t count = GetElementCount(num_dims, dst_dims);
if (count < GetMaxVal<int32_t>()) { if (count < GetMaxVal<int32_t>()) {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>( LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1); stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
} else { } else {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>( LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1); stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
} }
} }
template<BinaryOp op, typename T, typename R, size_t max_dims> template<BinaryOp op, typename T, typename R, size_t max_dims>
void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0, const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) { Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/, void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/,
const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/, const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/,
const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) =
nullptr; nullptr;
if (src0_pack_size == 1 && src1_pack_size == 1) { if (src0_pack_size == 1 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 1, 1>; func = DispatchIndexType<op, T, R, max_dims, 1, 1>;
} else if (src0_pack_size == 4 && src1_pack_size == 4) { } else if (src0_pack_size == 4 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 4, 4>; func = DispatchIndexType<op, T, R, max_dims, 4, 4>;
} else if (src0_pack_size == 1 && src1_pack_size == 4) { } else if (src0_pack_size == 1 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 1, 4>; func = DispatchIndexType<op, T, R, max_dims, 1, 4>;
} else if (src0_pack_size == 4 && src1_pack_size == 1) { } else if (src0_pack_size == 4 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 4, 1>; func = DispatchIndexType<op, T, R, max_dims, 4, 1>;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1); func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1);
} }
template<BinaryOp op, typename T, typename R> template<BinaryOp op, typename T, typename R>
void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0, const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) { Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/, void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/,
size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/,
const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/, const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/,
void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr; void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
CHECK_NE(num_dims, 1); CHECK_NE(num_dims, 1);
if (num_dims == 2) { if (num_dims == 2) {
func = DispatchPackSize<op, T, R, 2>; func = DispatchPackSize<op, T, R, 2>;
} else if (num_dims == 3) { } else if (num_dims == 3) {
func = DispatchPackSize<op, T, R, 3>; func = DispatchPackSize<op, T, R, 3>;
} else if (num_dims == 4) { } else if (num_dims == 4) {
func = DispatchPackSize<op, T, R, 4>; func = DispatchPackSize<op, T, R, 4>;
} else if (num_dims <= 8) { } else if (num_dims <= 8) {
func = DispatchPackSize<op, T, R, 8>; func = DispatchPackSize<op, T, R, 8>;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims,
dst, attr0, attr1); dst, attr0, attr1);
} }
// Returns the widest vectorized access width (in elements) usable for the
// innermost dimension of both sources and the destination, or 1 if none fits.
// Preconditions: at least one source has a non-broadcast (size != 1) innermost
// dimension. A size-1 innermost dimension is read element-wise, so it imposes
// no alignment constraint of its own.
template<size_t max_pack_size, typename T, typename R>
size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0,
                   const int64_t* src1_dims, const void* src1, void* dst) {
  // Pack widths must be powers of two so that halving walks all candidates.
  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, "");
  CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1);
  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
  // NOTE(review): the loop stops at `pack_size > 2`, so with kMaxPackSize == 4
  // only width 4 is tried before falling back to 1. This looks deliberate —
  // DispatchPackSize only instantiates widths 1 and 4 — but confirm it matches
  // the intended dispatch table.
  for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {
    bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1)
                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src0_dims, src0);
    bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1)
                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src1_dims, src1);
    // The destination is always written packed, so its pointer must be aligned
    // to the full pack width.
    if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) {
      return pack_size;
    }
  }
  return 1;
}
constexpr size_t kMaxPackSize = 4;  // widest vectorized access (in elements) tried by GetPackSize
template<BinaryOp op, typename T, typename R> template<BinaryOp op, typename T, typename R>
void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims, void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims,
const void* src0, int64_t* simplified_src1_dims, const void* src1, const void* src0, int64_t* simplified_src1_dims, const void* src1,
int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) { int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) {
CHECK_LE(simplified_num_dims, kMaxNumDims); CHECK_LE(simplified_num_dims, kMaxNumDims);
size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims, size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims,
src0, simplified_src1_dims, src1, dst); src0, simplified_src1_dims, src1, dst);
size_t src0_pack_size = 1; size_t src0_pack_size = 1;
size_t src1_pack_size = 1; size_t src1_pack_size = 1;
if (simplified_src0_dims[simplified_num_dims - 1] != 1) { if (simplified_src0_dims[simplified_num_dims - 1] != 1) {
simplified_src0_dims[simplified_num_dims - 1] /= pack_size; simplified_src0_dims[simplified_num_dims - 1] /= pack_size;
src0_pack_size = pack_size; src0_pack_size = pack_size;
} }
if (simplified_src1_dims[simplified_num_dims - 1] != 1) { if (simplified_src1_dims[simplified_num_dims - 1] != 1) {
simplified_src1_dims[simplified_num_dims - 1] /= pack_size; simplified_src1_dims[simplified_num_dims - 1] /= pack_size;
src1_pack_size = pack_size; src1_pack_size = pack_size;
} }
simplified_dst_dims[simplified_num_dims - 1] /= pack_size; simplified_dst_dims[simplified_num_dims - 1] /= pack_size;
DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims, DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims,
simplified_src0_dims, src0, simplified_src1_dims, src1, simplified_src0_dims, src0, simplified_src1_dims, src1,
simplified_dst_dims, dst, attr0, attr1); simplified_dst_dims, dst, attr0, attr1);
} }
// Adapts the binary functor to a unary one by fixing the LEFT operand to a
// scalar captured at construction time.
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarFunctor {
  __host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
      : scalar(scalar), functor(attr0, attr1) {}
  __device__ Dst operator()(Src src) const { return functor(scalar, src); }
  const Src scalar;  // fixed lhs operand
  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
// Adapts the binary functor to a unary one by fixing the RIGHT operand to a
// scalar captured at construction time.
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarFunctor {
  __host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
      : scalar(scalar), functor(attr0, attr1) {}
  __device__ Dst operator()(Src src) const { return functor(src, scalar); }
  const Src scalar;  // fixed rhs operand
  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
template<BinaryOp binary_op, typename Src, typename Dst> template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarPtrFunctorFactory { struct BinaryLhsScalarPtrFunctorFactory {
__host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0, __host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
Scalar attr1) Scalar attr1)
: scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {} : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
__device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const { __device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const {
return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1); return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
} }
const Src* scalar_ptr; const Src* scalar_ptr;
Scalar attr0, attr1; Scalar attr0, attr1;
}; };
// Factory invoked by cuda::elementwise::UnaryWithFactory: builds a functor
// whose RIGHT operand is *scalar_ptr. The pointer is only dereferenced inside
// operator(), which is __device__ code, so the pointed-to scalar never needs
// to be read on the host.
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarPtrFunctorFactory {
  __host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
                                                                Scalar attr1)
      : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
  __device__ BinaryRhsScalarFunctor<binary_op, Src, Dst> operator()() const {
    return BinaryRhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
  }
  const Src* scalar_ptr;  // non-owning pointer to the rhs scalar operand
  Scalar attr0, attr1;    // op attributes forwarded to the binary functor
};
// Top-level launcher for a tensor-tensor broadcast binary op: simplifies both
// shapes, then routes to the cheapest kernel that covers the resulting case.
template<BinaryOp binary_op, typename Src, typename Dst>
void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0,
                    size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst,
                    Scalar attr0, Scalar attr1) {
  auto* cuda_stream = stream->As<CudaStream>();
  size_t simplified_num_dims = 0;
  int64_t simplified_src0_dims[kMaxNumDims];
  int64_t simplified_src1_dims[kMaxNumDims];
  int64_t simplified_dst_dims[kMaxNumDims];
  // Collapse contiguous dims and bring both operands to a common rank.
  SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
                                     &simplified_num_dims, simplified_src0_dims,
                                     simplified_src1_dims, simplified_dst_dims);
  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,
               simplified_dst_dims, dst);
  if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
                   simplified_src1_dims)) {
    // No broadcasting remains after simplification: plain elementwise kernel.
    const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);
    OF_CUDA_CHECK((cuda::elementwise::Binary(
        BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst>(attr0, attr1), elem_cnt, dst, src0,
        src1, cuda_stream->cuda_stream())));
  } else {
    if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {
      // src0 collapsed to a single element: run a unary kernel over src1 with
      // src0 read through a pointer in device code.
      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
          BinaryLhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src0, attr0, attr1),
          simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream())));
    } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {
      // Mirror case: src1 is the single-element operand.
      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
          BinaryRhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src1, attr0, attr1),
          simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream())));
    } else {
      // General case: packed, index-type-dispatched broadcast kernel.
      LaunchWithSimplified<binary_op, Src, Dst>(stream, simplified_num_dims, simplified_src0_dims,
                                                src0, simplified_src1_dims, src1,
                                                simplified_dst_dims, dst, attr0, attr1);
    }
  }
}
// Converts a host-side Scalar into the concrete element type T expected by
// the scalar functors.
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}
// half specialization: round-trips through float and narrows. Presumably
// Scalar::Value has no half accessor — TODO confirm against Scalar's API.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// template<> // template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) { // nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value)); // return static_cast<nv_bfloat16>(GetValue<float>(value));
// } // }
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
// Concrete BroadcastElementwiseBinary primitive for one (op, Src, Dst) triple.
// The three Launch overloads cover: host-scalar lhs, host-scalar rhs, and the
// fully general tensor-tensor broadcast.
template<BinaryOp binary_op, typename Src, typename Dst>
class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);
  BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
  ~BroadcastElementwiseBinaryImpl() override = default;
  // scalar(lhs) op tensor(rhs): the scalar is converted on the host and baked
  // into the functor, so a single unary kernel suffices.
  void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
              const void* src1, void* dst) override {
    auto* cuda_stream = stream->As<CudaStream>();
    const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);
    OF_CUDA_CHECK((cuda::elementwise::Unary(
        BinaryLhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src0), attr0, attr1), elem_cnt,
        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src1),
        cuda_stream->cuda_stream())));
  }
  // tensor(lhs) op scalar(rhs): mirror of the overload above.
  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
              Scalar src1, void* dst) override {
    auto* cuda_stream = stream->As<CudaStream>();
    const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);
    OF_CUDA_CHECK((cuda::elementwise::Unary(
        BinaryRhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src1), attr0, attr1), elem_cnt,
        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src0),
        cuda_stream->cuda_stream())));
  }
  // tensor op tensor with broadcasting: defers to the full dispatch chain.
  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
              size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
              void* dst) override {
    DispatchLaunch<binary_op, Src, Dst>(
        stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,
        src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);
  }

 private:
  Scalar attr0, attr1;  // op attributes (e.g. alpha for scalar-parameterized ops)
};
} // namespace } // namespace
template<BinaryOp binary_op, typename Src, typename Dst> template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0, std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1) { Scalar attr1) {
return std::unique_ptr<BroadcastElementwiseBinary>( return std::unique_ptr<BroadcastElementwiseBinary>(
new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1)); new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));
} }
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h" #include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
// Explicit instantiations: every op in BINARY_ACTIVATION_BACKWARD_OP_SEQ
// crossed with each floating type, with Dst equal to Src.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,          \
                                                                           data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(          \
      Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ,
                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h" #include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
// Explicit instantiations: every comparison op crossed with each primitive
// source type; comparisons always produce bool outputs.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                        \
    binary_op, src_data_type_pair, dst_data_type_pair)                                         \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(  \
      Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
                                 BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h" #include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
// Explicit instantiations: every logical op crossed with each primitive source
// type; logical ops always produce bool outputs.
// Fix: the op sequence was `BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ`,
// which re-instantiated every comparison-op specialization already emitted by
// the comparison translation unit. Duplicate explicit instantiation
// definitions of the same specialization are ill-formed ([temp.explicit]), so
// this unit now instantiates only the logical ops.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
                                                                   dst_data_type_pair)            \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<             \
      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(     \
      Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
                                 BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h" #include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
// Explicit instantiations: every op in BINARY_MATH_OP_SEQ crossed with each
// primitive type, with Dst equal to Src.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair)     \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<          \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(          \
      Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifdef WITH_ROCM #ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h" #include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h" #include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h" #include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h" #include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_matmul { namespace broadcast_matmul {
namespace internal { namespace internal {
namespace { namespace {
constexpr size_t kMaxNumDims = 8; constexpr size_t kMaxNumDims = 8;
// Maps an oneflow DataType to the matching hipBLAS data-type enum.
// Returns NullOpt for types the hipBLAS GEMM path does not handle here
// (bfloat16 support is currently disabled on ROCm).
Optional<hipblasDatatype_t> OptCudaDataType(DataType data_type) {
  if (data_type == kFloat) { return HIPBLAS_R_32F; }
  if (data_type == kDouble) { return HIPBLAS_R_64F; }
  if (data_type == kFloat16) { return HIPBLAS_R_16F; }
  return NullOpt;
}
// Like OptCudaDataType, but CHECK-fails on unsupported types instead of
// returning an empty Optional.
hipblasDatatype_t GetCudaDataType(DataType data_type) {
  const auto opt_type = OptCudaDataType(data_type);
  CHECK(opt_type.has_value());
  // The fallback passed to value_or is never used: the CHECK above guarantees
  // a stored value is present.
  return opt_type.value_or(HIPBLAS_R_32F);
}
// Holds an alpha/beta GEMM scalar in the width hipblasGemm*Ex expects for the
// selected compute type.
union CublasScalarParameter {
  double d;  // used when compute type is HIPBLAS_R_64F
  float s;   // used for HIPBLAS_R_32F and HIPBLAS_R_16F (converted to __half by the caller)
};
// Converts a Scalar into the union member matching `compute_type`.
// For fp16 compute the caller converts the float member to __half afterwards.
CublasScalarParameter GetCublasScalarParameter(Scalar scalar, hipblasDatatype_t compute_type) {
  CublasScalarParameter result{};
  if (compute_type == HIPBLAS_R_64F) {
    result.d = scalar.Value<double>();
  } else if (compute_type == HIPBLAS_R_32F || compute_type == HIPBLAS_R_16F) {
    result.s = scalar.Value<float>();
  } else {
    UNIMPLEMENTED();
  }
  return result;
}
// Selects the hipBLAS compute (accumulation) type for a given data type.
// fp16 inputs currently compute in HIPBLAS_R_16F, so alpha/beta must be
// passed as __half for that path.
hipblasDatatype_t GetComputeType(DataType data_type) {
  switch (data_type) {
    case kDouble: return HIPBLAS_R_64F;
    case kFloat: return HIPBLAS_R_32F;
    case kFloat16: return HIPBLAS_R_16F;
    default: UNIMPLEMENTED(); return HIPBLAS_R_32F;
  }
}
void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a, void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b, int64_t num_batch_dims, BlasTransposeType transpose_b, int64_t num_batch_dims,
const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims, const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,
const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m, const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,
int64_t n, int64_t k, Scalar alpha, const void* a, const void* b, int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,
Scalar beta, void* c) { Scalar beta, void* c) {
auto* cuda_stream = stream->As<CudaStream>(); auto* cuda_stream = stream->As<CudaStream>();
const auto cuda_data_type = GetCudaDataType(data_type); const auto cuda_data_type = GetCudaDataType(data_type);
const auto compute_type = GetComputeType(data_type); const auto compute_type = GetComputeType(data_type);
const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type); const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
__half h_alpha = 0; __half h_alpha = 0;
if (compute_type == HIPBLAS_R_16F) { if (compute_type == HIPBLAS_R_16F) {
h_alpha = __float2half(sp_alpha.s); h_alpha = __float2half(sp_alpha.s);
} }
const auto GetCublasOperation = [](BlasTransposeType transpose_type) { const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
if (transpose_type == BlasTransposeType::N) { if (transpose_type == BlasTransposeType::N) {
return HIPBLAS_OP_N; return HIPBLAS_OP_N;
} else if (transpose_type == BlasTransposeType::T) { } else if (transpose_type == BlasTransposeType::T) {
return HIPBLAS_OP_T; return HIPBLAS_OP_T;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
return HIPBLAS_OP_N; return HIPBLAS_OP_N;
} }
}; };
const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b); const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b);
const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a); const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a);
const int cublas_m = n; const int cublas_m = n;
const int cublas_n = m; const int cublas_n = m;
const int cublas_k = k; const int cublas_k = k;
int cublas_lda = 0; int cublas_lda = 0;
if (transpose_b == BlasTransposeType::N) { if (transpose_b == BlasTransposeType::N) {
cublas_lda = n; cublas_lda = n;
} else if (transpose_b == BlasTransposeType::T) { } else if (transpose_b == BlasTransposeType::T) {
cublas_lda = k; cublas_lda = k;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
int cublas_ldb = 0; int cublas_ldb = 0;
if (transpose_a == BlasTransposeType::N) { if (transpose_a == BlasTransposeType::N) {
cublas_ldb = k; cublas_ldb = k;
} else if (transpose_a == BlasTransposeType::T) { } else if (transpose_a == BlasTransposeType::T) {
cublas_ldb = m; cublas_ldb = m;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
const int cublas_ldc = n; const int cublas_ldc = n;
// CublasMathModeGuard guard(cuda_stream->cublas_handle()); // CublasMathModeGuard guard(cuda_stream->cublas_handle());
// if (data_type == DataType::kFloat16) { // if (data_type == DataType::kFloat16) {
// #if CUDA_VERSION < 11000 // #if CUDA_VERSION < 11000
// guard.SetMathMode(CUBLAS_TENSOR_OP_MATH); // guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
// #else // #else
// guard.SetMathMode(CUBLAS_DEFAULT_MATH); // guard.SetMathMode(CUBLAS_DEFAULT_MATH);
// #endif // CUDA_VERSION < 11000 // #endif // CUDA_VERSION < 11000
// } // }
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; // hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT; hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
// #else // #else
// hipblasGemmAlgo_t algo = // hipblasGemmAlgo_t algo =
// (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT; // (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT;
// #endif // #endif
if (num_batch_dims == 1 && c_batch_dims[0] != 1) { if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
const void* cublas_a = b; const void* cublas_a = b;
const void* cublas_b = a; const void* cublas_b = a;
void* cublas_c = c; void* cublas_c = c;
const int64_t a_batch_count = a_batch_dims[0]; const int64_t a_batch_count = a_batch_dims[0];
const int64_t b_batch_count = b_batch_dims[0]; const int64_t b_batch_count = b_batch_dims[0];
CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count); CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);
CHECK_GT(a_batch_count, 0); CHECK_GT(a_batch_count, 0);
CHECK_GT(b_batch_count, 0); CHECK_GT(b_batch_count, 0);
const int batch_count = std::max(a_batch_count, b_batch_count); const int batch_count = std::max(a_batch_count, b_batch_count);
const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k; const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;
const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n; const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;
const long long int cublas_stride_c = cublas_m * cublas_n; const long long int cublas_stride_c = cublas_m * cublas_n;
const auto sp_beta = GetCublasScalarParameter(beta, compute_type); const auto sp_beta = GetCublasScalarParameter(beta, compute_type);
__half h_beta = 0; __half h_beta = 0;
if (compute_type == HIPBLAS_R_16F) { if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s); h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx( OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k, cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &h_beta, cublas_c, cuda_data_type, cublas_ldc, cublas_ldb, cublas_stride_b, &h_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo)); cublas_stride_c, batch_count, compute_type, algo));
} else { } else {
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx( OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k, cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo)); cublas_stride_c, batch_count, compute_type, algo));
} }
} else { } else {
auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) { auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {
const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type); const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);
__half h_beta = 0; __half h_beta = 0;
const void* cublas_a = batch_b; const void* cublas_a = batch_b;
const void* cublas_b = batch_a; const void* cublas_b = batch_a;
void* cublas_c = batch_c; void* cublas_c = batch_c;
if (compute_type == HIPBLAS_R_16F) { if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s); h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmEx( OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type, cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &h_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo)); cublas_ldb, &h_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
} else { } else {
OF_CUBLAS_CHECK(hipblasGemmEx( OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type, cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo)); cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
} }
}; };
ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims, ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,
a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func); a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);
} }
} }
class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory { class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {
public: public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);
BroadcastMatmulFactoryImpl() = default; BroadcastMatmulFactoryImpl() = default;
~BroadcastMatmulFactoryImpl() override = default; ~BroadcastMatmulFactoryImpl() override = default;
std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a, std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b, BlasTransposeType transpose_b,
size_t max_num_dims) override { size_t max_num_dims) override {
auto cuda_data_type = OptCudaDataType(data_type); auto cuda_data_type = OptCudaDataType(data_type);
if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) { if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) {
return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a, return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,
transpose_b); transpose_b);
} else { } else {
return nullptr; return nullptr;
} }
} }
}; };
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);
} // namespace } // namespace
} // namespace internal } // namespace internal
} // namespace broadcast_matmul } // namespace broadcast_matmul
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
#endif // WITH_ROCM #endif // WITH_ROCM
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h" #include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h" #include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Generic element cast: a plain static_cast between the two types.
// Specializations below handle `half`, which lacks direct conversions to/from
// most types on this toolchain.
template<typename To, typename From, typename = void>
struct CastFunctor {
  __device__ To operator()(From from) const { return static_cast<To>(from); }
};
// half -> To (To != half): route the conversion through float.
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
  __device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }
  // Converts two adjacent elements at once via a packed half2 load.
  // NOTE(review): assumes `from` points to two contiguous values suitably
  // aligned for half2 — presumably guaranteed by the elementwise launcher.
  __device__ void Apply2(To* to, const half* from) const {
    const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));
    to[0] = static_cast<To>(f2.x);
    to[1] = static_cast<To>(f2.y);
  }
};
// From -> half (From != half): route the conversion through float.
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
  __device__ half operator()(From from) const {
    return static_cast<half>(static_cast<float>(from));
  }
  // Converts two adjacent elements at once with a single packed half2 store
  // (round-to-nearest-even).
  // NOTE(review): assumes `to` is half2-aligned — presumably guaranteed by the
  // elementwise launcher.
  __device__ void Apply2(half* to, const From* from) const {
    float2 f2;
    f2.x = static_cast<float>(from[0]);
    f2.y = static_cast<float>(from[1]);
    *reinterpret_cast<half2*>(to) = __float22half2_rn(f2);
  }
};
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// template<typename To> // template<typename To>
// struct CastFunctor<To, nv_bfloat16, // struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value // typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> { // || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const { // __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from)); // return static_cast<To>(static_cast<float>(from));
// } // }
// }; // };
// template<typename From> // template<typename From>
// struct CastFunctor<nv_bfloat16, From, // struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value // typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> { // || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const { // __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from)); // return static_cast<nv_bfloat16>(static_cast<float>(from));
// } // }
// }; // };
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
// Cast primitive backed by the generic elementwise Unary launcher: converts
// `count` elements of type From at `from` into type To at `to` on the
// stream's device.
template<typename From, typename To>
class CastImpl : public Cast {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CastImpl);
  explicit CastImpl() = default;
  ~CastImpl() override = default;

  void Launch(Stream* stream, const void* from, void* to, size_t count) override {
    auto* cuda_stream = stream->As<CudaStream>();
    // CastFunctor<To, From> supplies the per-element conversion (and the
    // two-element Apply2 path where specialized for half).
    OF_CUDA_CHECK((cuda::elementwise::Unary<CastFunctor<To, From>, To, From>(
        CastFunctor<To, From>(), count, reinterpret_cast<To*>(to),
        reinterpret_cast<const From*>(from), cuda_stream->cuda_stream())));
  }
};
template<typename From, typename To> template<typename From, typename To>
std::unique_ptr<Cast> NewCast() { std::unique_ptr<Cast> NewCast() {
return std::unique_ptr<Cast>(new CastImpl<From, To>()); return std::unique_ptr<Cast>(new CastImpl<From, To>());
} }
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \ #define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \ CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \ CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_UINT32_TYPE_SEQ \ CUDA_PRIMITIVE_UINT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_UINT64_TYPE_SEQ \ CUDA_PRIMITIVE_UINT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// Factory producing Cast primitives from a statically generated lookup table.
class CastFactoryImpl : public CastFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);
  CastFactoryImpl() = default;
  ~CastFactoryImpl() override = default;

  // Returns a Cast primitive for the (from, to) data-type pair, or nullptr
  // when the pair is not in the generated conversion table.
  std::unique_ptr<Cast> New(DataType from, DataType to) override {
// Table entry: (from, to) DataType pair -> factory function for CastImpl<From, To>.
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair)                              \
  {std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
   NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
    // Built once, thread-safely, on first call; covers the full cartesian
    // product of CUDA_PRIMITIVE_CAST_TYPE_SEQ with itself.
    static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>
        new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
            MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_CAST_TYPE_SEQ, CUDA_PRIMITIVE_CAST_TYPE_SEQ)};
#undef MAKE_NEW_CAST_ENTRY
    const auto it = new_cast_handle.find(std::make_pair(from, to));
    if (it != new_cast_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/include/primitive/constant_pad.h" #include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h" #include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Writes `dst` so that elements inside the valid (non-padded) region are
// copied from `src` and everything else is set to `packed_pad_val`.
// Works in units of StorageType; params.elem_cnt and the index helpers are
// expressed in those units. NOTE(review): packed_pad_val presumably holds the
// pad value replicated across the storage word — confirm against the caller.
template<size_t num_dims, typename IndexType, typename StorageType>
__global__ void ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params,
                                  StorageType packed_pad_val) {
  const StorageType* src = reinterpret_cast<const StorageType*>(params.src);
  StorageType* dst = reinterpret_cast<StorageType*>(params.dst);
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) {
    params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);
    bool if_pad = false;
#pragma unroll
    for (int i = 0; i < num_dims; i++) {
      if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) {
        // Inside the valid region along this dim: map back to the src coordinate.
        src_index[i] = dst_index[i] - params.valid_start[i];
      } else {
        // Outside [valid_start, valid_end) on any dim means this element is padding.
        if_pad = true;
        break;
      }
    }
    StorageType dst_val = packed_pad_val;
    if (!if_pad) {
      const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
      dst_val = src[src_offset];
    }
    dst[linear_index] = dst_val;
  }
}
// Specialization for half: extract the scalar as float first, since Scalar
// has no direct half accessor.
template<>
half GetValue<half>(Scalar value) {
  return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// template<> // template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) { // nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value)); // return static_cast<nv_bfloat16>(GetValue<float>(value));
// } // }
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
// Launches ConstantPadKernel on the stream with a default-waves grid sized
// for elem_cnt (elem_cnt counted in StorageType units).
template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, ConstantPadParams<num_dims, IndexType> params,
                  StorageType packed_pad_val, size_t elem_cnt) {
  stream->As<CudaStream>()->LaunchKernelDefaultWaves(
      (ConstantPadKernel<num_dims, IndexType, StorageType>), elem_cnt, params, packed_pad_val);
}
template<size_t num_dims, typename IndexType, typename StorageType> template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src, void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before, const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
ConstantPadParams<num_dims, IndexType> params; ConstantPadParams<num_dims, IndexType> params;
params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims); params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims);
params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims); params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);
params.dst = dst; params.dst = dst;
params.src = src; params.src = src;
for (int i = 0; i < num_dims; i++) { for (int i = 0; i < num_dims; i++) {
params.valid_start[i] = padding_before[i]; params.valid_start[i] = padding_before[i];
params.valid_end[i] = dst_dims[i] - padding_after[i]; params.valid_end[i] = dst_dims[i] - padding_after[i];
} }
params.elem_cnt = elem_cnt; params.elem_cnt = elem_cnt;
LaunchKernel<num_dims, IndexType, StorageType>(stream, params, packed_pad_val, elem_cnt); LaunchKernel<num_dims, IndexType, StorageType>(stream, params, packed_pad_val, elem_cnt);
} }
template<size_t num_dims, typename StorageType> template<size_t num_dims, typename StorageType>
void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src, void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before, const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) { const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
if (elem_cnt < GetMaxVal<int32_t>()) { if (elem_cnt < GetMaxVal<int32_t>()) {
LaunchKernel<num_dims, int32_t, StorageType>(stream, dst, dst_dims, src, src_dims, LaunchKernel<num_dims, int32_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val, padding_before, padding_after, packed_pad_val,
elem_cnt); elem_cnt);
} else { } else {
LaunchKernel<num_dims, int64_t, StorageType>(stream, dst, dst_dims, src, src_dims, LaunchKernel<num_dims, int64_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val, padding_before, padding_after, packed_pad_val,
elem_cnt); elem_cnt);
} }
} }
template<size_t num_dims, typename T> template<size_t num_dims, typename T>
void DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src, void DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src,
int64_t* src_dims, int64_t* padding_before, int64_t* padding_after, int64_t* src_dims, int64_t* padding_before, int64_t* padding_after,
T pad_val) { T pad_val) {
constexpr int32_t max_packsize = GetMaxPackSize<T>(); constexpr int32_t max_packsize = GetMaxPackSize<T>();
size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims, size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims,
padding_before, padding_after); padding_before, padding_after);
dst_dims[num_dims - 1] /= launch_pack_size; dst_dims[num_dims - 1] /= launch_pack_size;
src_dims[num_dims - 1] /= launch_pack_size; src_dims[num_dims - 1] /= launch_pack_size;
padding_before[num_dims - 1] /= launch_pack_size; padding_before[num_dims - 1] /= launch_pack_size;
padding_after[num_dims - 1] /= launch_pack_size; padding_after[num_dims - 1] /= launch_pack_size;
size_t elem_cnt = 1; size_t elem_cnt = 1;
for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; } for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; }
if (launch_pack_size == 1) { if (launch_pack_size == 1) {
Pack<T, 1> packed_pad_val(pad_val); Pack<T, 1> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 1>>(stream, dst, dst_dims, src, src_dims, DispatchIndexType<num_dims, PackType<T, 1>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, padding_before, padding_after,
packed_pad_val.storage, elem_cnt); packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 2) { } else if (launch_pack_size == 2) {
Pack<T, 2> packed_pad_val(pad_val); Pack<T, 2> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 2>>(stream, dst, dst_dims, src, src_dims, DispatchIndexType<num_dims, PackType<T, 2>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, padding_before, padding_after,
packed_pad_val.storage, elem_cnt); packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 4) { } else if (launch_pack_size == 4) {
Pack<T, 4> packed_pad_val(pad_val); Pack<T, 4> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 4>>(stream, dst, dst_dims, src, src_dims, DispatchIndexType<num_dims, PackType<T, 4>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, padding_before, padding_after,
packed_pad_val.storage, elem_cnt); packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 8) { } else if (launch_pack_size == 8) {
Pack<T, 8> packed_pad_val(pad_val); Pack<T, 8> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 8>>(stream, dst, dst_dims, src, src_dims, DispatchIndexType<num_dims, PackType<T, 8>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, padding_before, padding_after,
packed_pad_val.storage, elem_cnt); packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 16) { } else if (launch_pack_size == 16) {
Pack<T, 16> packed_pad_val(pad_val); Pack<T, 16> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 16>>(stream, dst, dst_dims, src, src_dims, DispatchIndexType<num_dims, PackType<T, 16>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, padding_before, padding_after,
packed_pad_val.storage, elem_cnt); packed_pad_val.storage, elem_cnt);
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
} }
template<typename T> template<typename T>
void LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims, void LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims,
const void* src, int64_t* src_dims, int64_t* padding_before, const void* src, int64_t* src_dims, int64_t* padding_before,
int64_t* padding_after, T pad_val) { int64_t* padding_after, T pad_val) {
void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/, void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/,
int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) = int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) =
nullptr; nullptr;
if (num_dims == 1) { if (num_dims == 1) {
func = DispatchPackSize<1, T>; func = DispatchPackSize<1, T>;
} else if (num_dims == 2) { } else if (num_dims == 2) {
func = DispatchPackSize<2, T>; func = DispatchPackSize<2, T>;
} else if (num_dims == 3) { } else if (num_dims == 3) {
func = DispatchPackSize<3, T>; func = DispatchPackSize<3, T>;
} else if (num_dims == 4) { } else if (num_dims == 4) {
func = DispatchPackSize<4, T>; func = DispatchPackSize<4, T>;
} else if (num_dims == 5) { } else if (num_dims == 5) {
func = DispatchPackSize<5, T>; func = DispatchPackSize<5, T>;
} else if (num_dims == 6) { } else if (num_dims == 6) {
func = DispatchPackSize<6, T>; func = DispatchPackSize<6, T>;
} else if (num_dims == 7) { } else if (num_dims == 7) {
func = DispatchPackSize<7, T>; func = DispatchPackSize<7, T>;
} else if (num_dims == 8) { } else if (num_dims == 8) {
func = DispatchPackSize<8, T>; func = DispatchPackSize<8, T>;
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val); func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val);
} }
template<typename T> template<typename T>
void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src, void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, T pad_val, const int64_t* padding_before, const int64_t* padding_after, T pad_val,
void* dst) { void* dst) {
CHECK_LE(num_dims, kMaxNumDims); CHECK_LE(num_dims, kMaxNumDims);
int64_t simplified_dst_dims[kMaxNumDims]; int64_t simplified_dst_dims[kMaxNumDims];
int64_t simplified_src_dims[kMaxNumDims]; int64_t simplified_src_dims[kMaxNumDims];
int64_t simplified_padding_before[kMaxNumDims]; int64_t simplified_padding_before[kMaxNumDims];
int64_t simplified_padding_after[kMaxNumDims]; int64_t simplified_padding_after[kMaxNumDims];
size_t simplified_num_dims = 1; size_t simplified_num_dims = 1;
SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims, SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims,
simplified_dst_dims, simplified_src_dims, simplified_padding_before, simplified_dst_dims, simplified_src_dims, simplified_padding_before,
simplified_padding_after); simplified_padding_after);
LaunchWithSimplified<T>(stream, simplified_num_dims, dst, simplified_dst_dims, src, LaunchWithSimplified<T>(stream, simplified_num_dims, dst, simplified_dst_dims, src,
simplified_src_dims, simplified_padding_before, simplified_padding_after, simplified_src_dims, simplified_padding_before, simplified_padding_after,
pad_val); pad_val);
} }
// ConstantPad primitive for the CUDA/ROCm device, specialized on element type T.
template<typename T>
class ConstantPadImpl : public ConstantPad {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl);
  ConstantPadImpl() = default;
  ~ConstantPadImpl() override = default;
  // Pads `src` into `dst`: copies the src region and fills the surrounding border
  // with `pad_val`, converted to T via GetValue<T>.
  void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
              const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,
              void* dst) override {
    SimplifyThenLaunch<T>(stream, num_dims, src_dims, src, padding_before, padding_after,
                          GetValue<T>(pad_val), dst);
  }
};
template<typename T> template<typename T>
std::unique_ptr<ConstantPad> NewConstantPad() { std::unique_ptr<ConstantPad> NewConstantPad() {
return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>()); return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>());
} }
// Factory that creates ConstantPad primitives for the CUDA/ROCm device.
class ConstantPadFactoryImpl : public ConstantPadFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl);
  ConstantPadFactoryImpl() = default;
  ~ConstantPadFactoryImpl() override = default;
  // Returns a ConstantPad for `data_type`, or nullptr if the type is unsupported.
  std::unique_ptr<ConstantPad> New(DataType data_type) override {
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
    // data_type -> constructor; one entry per type in CUDA_PRIMITIVE_ALL_TYPE_SEQ.
    static const std::map<DataType, std::function<std::unique_ptr<ConstantPad>()>>
        new_constant_pad_handle{
            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
    const auto it = new_constant_pad_handle.find(data_type);
    if (it != new_constant_pad_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
// Makes the factory discoverable via primitive lookup for DeviceType::kCUDA.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h" #include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Copies an n-d sub-region. Each iteration maps a linear id within the copy extent
// to an n-d coordinate, offsets it by the src/dst start positions, and moves one
// `movement_size`-byte element. The loop is provided by the CUDA_1D_KERNEL_LOOP_T macro.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {
  // Opaque element type whose size and alignment both equal movement_size.
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);
  IndexType region_index[num_dims];
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.copy_index_helper.OffsetToNdIndex(i, region_index);
#pragma unroll
    for (size_t j = 0; j < num_dims; ++j) {
      src_index[j] = region_index[j] + params.src_pos[j];
      dst_index[j] = region_index[j] + params.dst_pos[j];
    }
    const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index);
    const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
    dst[dst_offset] = src[src_offset];
  }
}
template<size_t num_dims, size_t movement_size, typename IndexType> template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) { void LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream(); hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
CopyNdKernel<num_dims, movement_size, IndexType> CopyNdKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params); <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
} }
// CopyNd primitive backed by CopyNdKernel on the CUDA/ROCm device.
class CopyNdImpl : public CopyNd {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl);
  CopyNdImpl() = default;
  ~CopyNdImpl() override = default;
  // Copies the `extent` region from src (starting at src_pos) into dst (at dst_pos).
  // Delegates to the shared SimplifyThenLaunch from common/primitive/copy_nd.h.
  void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,
              const int64_t* dst_dims, const int64_t* dst_pos, const void* src,
              const int64_t* src_dims, const int64_t* src_pos,
              const int64_t* extent) const override {
    SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos,
                       extent);
  }
};
class CopyNdFactoryImpl : public CopyNdFactory { class CopyNdFactoryImpl : public CopyNdFactory {
public: public:
OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl);
CopyNdFactoryImpl() = default; CopyNdFactoryImpl() = default;
~CopyNdFactoryImpl() override = default; ~CopyNdFactoryImpl() override = default;
std::unique_ptr<CopyNd> New(size_t max_num_dims) override { std::unique_ptr<CopyNd> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) { if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<CopyNd>(new CopyNdImpl()); return std::unique_ptr<CopyNd>(new CopyNdImpl());
} else { } else {
return nullptr; return nullptr;
} }
} }
}; };
// Makes the factory discoverable via primitive lookup for DeviceType::kCUDA.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/common/primitive/elementwise_unary.h" #include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h" #include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
template<UnaryOp unary_op, typename Src, typename Dst> template<UnaryOp unary_op, typename Src, typename Dst>
class ElementwiseUnaryImpl : public ElementwiseUnary { class ElementwiseUnaryImpl : public ElementwiseUnary {
public: public:
OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl); OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);
ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {} ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~ElementwiseUnaryImpl() override = default; ~ElementwiseUnaryImpl() override = default;
void Launch(Stream* stream, const void* src, void* dst, size_t count) override { void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>(); auto* cuda_stream = stream->As<CudaStream>();
auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1); auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);
OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>( OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(
functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src), functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src),
cuda_stream->cuda_stream()))); cuda_stream->cuda_stream())));
} }
protected: protected:
Scalar attr0, attr1; Scalar attr0, attr1;
}; };
template<UnaryOp unary_op, typename Src, typename Dst> template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) { std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {
return std::unique_ptr<ElementwiseUnary>( return std::unique_ptr<ElementwiseUnary>(
new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1)); new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
} }
// Factory that creates ElementwiseUnary primitives for the CUDA/ROCm device.
// Supported (op, src_type, dst_type) combinations are enumerated in a static map.
class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);
  ElementwiseUnaryFactoryImpl() = default;
  ~ElementwiseUnaryFactoryImpl() override = default;
  // Convenience overload: no op attributes.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
                                        DataType dst_dtype) override {
    return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());
  }
  // Convenience overload: single op attribute.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0) override {
    return New(unary_op, src_type, dst_dtype, attr0, Scalar());
  }
  // Full overload: looks up (op, src_type, dst_type) in the static dispatch map
  // and constructs the primitive with both attributes; nullptr if unsupported.
  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
                                        Scalar attr0, Scalar attr1) override {
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)                   \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair)   \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)),  \
   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair),                                   \
                       OF_PP_PAIR_FIRST(dst_dtype_pair)>},
    static const std::map<std::tuple<UnaryOp, DataType, DataType>,
                          std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>
        new_elementwise_unary_handle{
            // For All Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
            // For Float Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_FLOATING_MATH_OP_SEQ,
                                             CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
            // For Utils OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
            // For Logical OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
    const auto it =
        new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));
    if (it != new_elementwise_unary_handle.end()) {
      return it->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};
// Makes the factory discoverable via primitive lookup for DeviceType::kCUDA.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/include/primitive/fill.h" #include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// POD storage type with size == alignment == `size`; used by Pack/FillGpu so a
// whole pack of elements can be written with a single assignment.
template<size_t size>
using Storage = typename std::aligned_storage<size, size>::type;
// A pack of `pack` elements of T overlaid with a single Storage word of the same
// size, so the replicated value can be stored in one wide write.
template<typename T, size_t pack>
union Pack {
  static constexpr size_t size = sizeof(T) * pack;
  // Replicates `value` into every slot of the pack.
  explicit __device__ __host__ Pack(T value) {
    // Layout guards: the union must be exactly one naturally-aligned word.
    static_assert(sizeof(Pack) == size, "");
    static_assert(alignof(Pack) == size, "");
#pragma unroll
    for (size_t i = 0; i < pack; ++i) { elem[i] = value; }
  }
  T elem[pack];
  Storage<size> storage;  // aliased view used for the packed store
};
// Fills `dst[0, count)` with `value`: the bulk is written as packed stores of
// `pack` elements; the remainder (count % pack) is written element-by-element.
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
  Pack<T, pack> packed(value);
  const size_t num_packs = count / pack;
  auto* packed_dst = reinterpret_cast<decltype(packed.storage)*>(dst);
  CUDA_1D_KERNEL_LOOP_T(size_t, i, num_packs) { packed_dst[i] = packed.storage; }
  const size_t tail_count = count - num_packs * pack;
  T* tail_dst = dst + num_packs * pack;
  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}
// Extracts the fill value from a Scalar as type T (specialized below for half).
template<typename T>
T GetValue(Scalar value) {
  return value.Value<T>();
}
// Scalar is widened to float first, then narrowed to half for device use.
template<>
half GetValue<half>(Scalar value) {
  const float float_val = GetValue<float>(value);
  return static_cast<half>(float_val);
}
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// template<> // template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) { // nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value)); // return static_cast<nv_bfloat16>(GetValue<float>(value));
// } // }
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
template<typename T, size_t pack> template<typename T, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst, typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
T value, size_t count) { T value, size_t count) {
FillGpu<T, pack> FillGpu<T, pack>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count); <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
} }
// Overload selected when pack == 0, i.e. the alignment-derived pack width
// (alignment bytes / sizeof(T)) underflowed to zero because the destination
// pointer is not even aligned to sizeof(T). That is a caller error, so abort.
template<typename T, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
                                                                T value, size_t count) {
  LOG(FATAL) << "wrong alignment";
}
template<typename T> template<typename T>
void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) { void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) {
auto uintptr = reinterpret_cast<std::uintptr_t>(dst); auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
if (uintptr % 16 == 0) { if (uintptr % 16 == 0) {
LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count); LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 8 == 0) { } else if (uintptr % 8 == 0) {
LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count); LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 4 == 0) { } else if (uintptr % 4 == 0) {
LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count); LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 2 == 0) { } else if (uintptr % 2 == 0) {
LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count); LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);
} else { } else {
LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count); LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);
} }
} }
// Fill primitive for element type T: writes one scalar value into every
// element of a device buffer.
template<typename T>
class FillImpl : public Fill {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillImpl);
  FillImpl() = default;
  ~FillImpl() override = default;
  // Converts `value` to T and launches the (possibly vectorized) fill kernel
  // on the stream's underlying HIP stream.
  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
  }
};
template<typename T> template<typename T>
std::unique_ptr<Fill> NewFill() { std::unique_ptr<Fill> NewFill() {
return std::unique_ptr<Fill>(new FillImpl<T>()); return std::unique_ptr<Fill>(new FillImpl<T>());
} }
// Factory creating Fill primitives keyed by element data type.
class FillFactoryImpl : public FillFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
  FillFactoryImpl() = default;
  ~FillFactoryImpl() override = default;
  // Returns a Fill for `data_type`, or nullptr when the type is unsupported.
  // The dispatch table is generated once from CUDA_PRIMITIVE_ALL_TYPE_SEQ.
  std::unique_ptr<Fill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
    const auto it = new_fill_handle.find(data_type);
    if (it != new_fill_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifdef WITH_ROCM #ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h" #include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Memcpy primitive: asynchronous copy using hipMemcpyDefault, which lets the
// HIP runtime infer the transfer direction from the pointer values.
class MemcpyImpl : public Memcpy {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);
  MemcpyImpl() = default;
  ~MemcpyImpl() override = default;
  void Launch(Stream* stream, void* dst, const void* src, size_t count) override {
    // Copying a buffer onto itself is a no-op.
    if (dst == src) { return; }
    auto* cuda_stream = stream->As<CudaStream>();
    OF_CUDA_CHECK(hipMemcpyAsync(dst, src, count, hipMemcpyDefault, cuda_stream->cuda_stream()));
  }
};
class MemcpyFactoryImpl : public MemcpyFactory { class MemcpyFactoryImpl : public MemcpyFactory {
public: public:
OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);
MemcpyFactoryImpl() = default; MemcpyFactoryImpl() = default;
~MemcpyFactoryImpl() override = default; ~MemcpyFactoryImpl() override = default;
std::unique_ptr<Memcpy> New(MemcpyKind kind) override { std::unique_ptr<Memcpy> New(MemcpyKind kind) override {
return std::unique_ptr<Memcpy>(new MemcpyImpl()); return std::unique_ptr<Memcpy>(new MemcpyImpl());
} }
}; };
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
#endif #endif
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifdef WITH_ROCM #ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Memset primitive: asynchronously sets `count` bytes starting at `ptr` via
// hipMemsetAsync on the stream's HIP stream.
class MemsetImpl : public Memset {
 public:
  OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);
  MemsetImpl() = default;
  ~MemsetImpl() override = default;
  void Launch(Stream* stream, void* ptr, int value, size_t count) override {
    auto* cuda_stream = stream->As<CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(ptr, value, count, cuda_stream->cuda_stream()));
  }
};
class MemsetFactoryImpl : public MemsetFactory { class MemsetFactoryImpl : public MemsetFactory {
public: public:
OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);
MemsetFactoryImpl() = default; MemsetFactoryImpl() = default;
~MemsetFactoryImpl() override = default; ~MemsetFactoryImpl() override = default;
std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); } std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); }
}; };
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
#endif #endif
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h" #include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace permute { namespace permute {
namespace internal { namespace internal {
namespace { namespace {
// Tile edge length (in elements) for the movement_size=4 transpose path.
constexpr int32_t kMov4TileSize = 32;
// Tile edge length (in elements) for the movement_size=2 (16-bit element) path.
constexpr int32_t kMov2TileSize = 64;
// Thread-block y-dimension: each block covers a tile in tile_size/kBlockRows row steps.
constexpr int32_t kBlockRows = 8;
// Generic permute kernel: each thread walks destination offsets in a strided
// loop, converts the flat offset to a destination nd-index, maps it through
// the permutation to the corresponding source nd-index, and copies one
// movement_size-byte element.
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
  // T is an opaque POD chunk of movement_size bytes: the copy granularity.
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  const T* src = reinterpret_cast<const T*>(params.src);
  T* dst = reinterpret_cast<T*>(params.dst);
  IndexType src_index[num_dims];
  IndexType dst_index[num_dims];
  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
    params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
    for (size_t dim = 0; dim < num_dims; ++dim) {
      // Destination dim `dim` corresponds to source dim `permutation[dim]`.
      src_index[params.permutation[dim]] = dst_index[dim];
    }
    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
    dst[i] = src[src_offset];
  }
}
// Batched tiled transpose: (B, X, Y) -> (B, Y, X).
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
// Blocks stride over `block_nums` logical tiles; each tile is staged through
// shared memory (padded by one column to avoid bank conflicts) so both the
// global read and the global write iterate along rows.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                     IndexType cols, IndexType num_tile_rows,
                                     IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  using T = typename std::aligned_storage<movement_size, movement_size>::type;
  __shared__ T tile[tile_size][tile_size + 1];  // To avoid bank conflict.
  const T* src = reinterpret_cast<const T*>(src_ptr);
  T* dst = reinterpret_cast<T*>(dst_ptr);
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    const IndexType tile_index =
        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the
                                           // flatten index of tile in a batch.
    const IndexType tile_row_index =
        tile_index / num_tile_cols;  // the row index of tile in a batch.
    const IndexType tile_col_index =
        tile_index
        - tile_row_index
              * num_tile_cols;  // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType offset = batch_index * src_rows * src_cols;
    // Phase 1: cooperatively load one tile from src into shared memory.
    {
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        // Guard: edge tiles may extend past the matrix boundary.
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
        }
      }
    }
    __syncthreads();
    // Phase 2: write the tile out transposed (row/col roles swapped).
    {
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
        }
      }
    }
    // Make sure all reads of the shared tile finish before the next
    // stride iteration overwrites it.
    __syncthreads();
  }
}
/*
MovementSize=2 variant of the batch transpose.
When both H and W are divisible by 2, pairs of 2-byte elements can be read
with movement_size=4 and written back with movement_size=4.
*/
// Movement-size-2 transpose: elements are 2 bytes wide, but global loads and
// stores move 4 bytes (two adjacent elements) at a time. A shared-memory union
// exposes the staged tile both as 2-byte cells (for the transposed gather) and
// as 4-byte cells (for the vectorized load).
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
                                              IndexType cols, IndexType num_tile_rows,
                                              IndexType num_tile_cols, int32_t block_nums) {
  const IndexType src_rows = rows;
  const IndexType src_cols = cols;
  const IndexType dst_rows = cols;
  const IndexType dst_cols = rows;
  // Pairing 2-byte elements into 4-byte accesses requires an even tile edge.
  static_assert(tile_size % 2 == 0, "");
  using T_MOV2 = typename std::aligned_storage<2, 2>::type;
  using T_MOV4 = typename std::aligned_storage<4, 4>::type;
  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
  // Use union structure to process Load and Store.
  __shared__ union {
    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half [64][66]
    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]
  } tile_mem;
  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
    const IndexType batch_index = i / batch_num_tile;  // the index of batch.
    const IndexType tile_index =
        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the
                                           // flatten index of tile in a batch.
    const IndexType tile_row_index =
        tile_index / num_tile_cols;  // the row index of tile in a batch.
    const IndexType tile_col_index =
        tile_index
        - tile_row_index
              * num_tile_cols;  // equal to k % num_tile_cols. the col index of tile in a batch.
    const IndexType offset = batch_index * src_rows * src_cols;
    // Phase 1: load the tile with 4-byte accesses (two elements per thread).
    {
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
          tile_mem.tile_m4[row_in_tile][col_in_tile] =
              src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
        }
      }
    }
    __syncthreads();
    // Phase 2: gather two transposed 2-byte cells, fuse them into one 4-byte
    // value, and store it.
    {
      IndexType col_in_tile = threadIdx.x;
      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
           row_in_tile += kBlockRows) {
        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
        union {
          T_MOV4 m4;
          T_MOV2 m2[2];
        } tmp_storage;
        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
          tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
          tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
          dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
        }
      }
    }
    // Protect the shared tile before the next stride iteration reuses it.
    __syncthreads();
  }
}
// Launches the tiled batch transpose. `tile_size` is a compile-time constant:
// when it equals kMov2TileSize the movement-2 kernel is launched with
// tile_size/2 threads in x (each thread moves two 2-byte elements per access);
// otherwise the generic BatchTransposeKernel runs with tile_size threads in x.
// The grid is capped at kCudaMaxBlocksNum; the kernels stride over the
// remaining `block_nums` tiles.
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
void LaunchBatchTransposeKernel(hipStream_t& cuda_stream,
                                const PermuteKernelParams<num_dims, IndexType>& params,
                                const IndexType& num_batches, const IndexType& rows,
                                const IndexType& cols) {
  // Ceil-divide each matrix dimension into tiles.
  IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;
  IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;
  const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
  int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
  if (tile_size == kMov2TileSize) {
    const int32_t half2_thread = tile_size / 2;  // because each thread processes two half elements.
    BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>
        <<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
            block_nums);  // Threads are 32x8 because each thread moves
                          // 4 elements through the 64x66 half shared tile.
  } else {
    BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>
        <<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
  }
}
// Returns true when both matrix dimensions span at least one full tile, i.e.
// the tiled transpose kernel has a complete tile to work with in each
// direction.
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
  return rows >= tile_size && cols >= tile_size;
}
// Decides whether the (batched) transpose fast path applies: the matrix must
// be at least one tile in each dimension and the permutation must swap the
// two innermost dimensions.
template<size_t num_dims, size_t tile_size, typename IndexType>
bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
                               const IndexType& rows, const IndexType& cols) {
  if (!CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) { return false; }
  if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) {
    // 2d tensor case: (0, 1) -> (1, 0)
    return true;
  }
  // 3d tensor case: (0, 1, 2) -> (0, 2, 1)
  return num_dims == 3 && permutation[2] == 1 && permutation[1] == 2;
}
// True when the movement-2 transpose can be upgraded to 4-byte vectorized
// accesses: elements are 2 bytes wide, both dims are even (so elements pair
// up), and both pointers are 4-byte aligned.
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
  if (movement_size != 2) { return false; }
  if (rows % 2 != 0 || cols % 2 != 0) { return false; }
  const auto src_addr = reinterpret_cast<std::uintptr_t>(src);
  const auto dst_addr = reinterpret_cast<std::uintptr_t>(dst);
  return (src_addr % 4 == 0) && (dst_addr % 4 == 0);
}
// Derives (num_batches, rows, cols) from the source dims: a 2-D tensor is a
// single-batch matrix; for 3-D the leading dimension is the batch count and
// the remaining two are the matrix dimensions.
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
                              IndexType* cols) {
  if (num_dims != 2) {
    *num_batches = src_dims[0];
    *rows = src_dims[1];
    *cols = src_dims[2];
  } else {
    *num_batches = 1;
    *rows = src_dims[0];
    *cols = src_dims[1];
  }
}
// Dispatches one permute call. For 2-D/3-D tensors whose permutation swaps the
// two innermost dims, the tiled batch-transpose fast path is used (upgraded to
// the vectorized movement-2 variant when CheckUseMov2 allows it); every other
// case falls back to the generic PermuteKernel.
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,
                  void* dst, size_t count) {
  PermuteKernelParams<num_dims, IndexType> params =
      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
  hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
  if (num_dims == 2 || num_dims == 3) {
    IndexType num_batches;
    IndexType rows;
    IndexType cols;
    InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
    if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,
                                                           cols)) {
      if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {
        // 2-byte elements with even dims and 4-byte-aligned pointers: use the
        // vectorized movement-2 kernel with its larger (kMov2TileSize) tile.
        LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
                                                                          num_batches, rows, cols);
      } else {
        LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
            cuda_stream, params, num_batches, rows, cols);
      }
    } else {
      // 2-D/3-D but not a plain innermost-dim swap: generic kernel.
      PermuteKernel<num_dims, movement_size, IndexType>
          <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
    }
  } else {
    // Other ranks always use the generic kernel.
    PermuteKernel<num_dims, movement_size, IndexType>
        <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
  }
}
// Permute primitive for this (HIP-backed, CUDA-named) device: forwards to
// SimplifyThenLaunch, which is defined earlier in this file and dispatches the
// transpose/permute kernels above.
class PermuteImpl : public Permute {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
  PermuteImpl() = default;
  ~PermuteImpl() override = default;
  using Permute::Launch;  // keep the base class's other Launch overloads visible
  // Permutes `src` (rank `num_dims`, extents `src_dims`) into `dst` according
  // to `permutation`; `data_type` selects the element type of both buffers.
  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
              const void* src, const int* permutation, void* dst) override {
    SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);
  }
};
class PermuteFactoryImpl : public PermuteFactory { class PermuteFactoryImpl : public PermuteFactory {
public: public:
OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);
PermuteFactoryImpl() = default; PermuteFactoryImpl() = default;
~PermuteFactoryImpl() override = default; ~PermuteFactoryImpl() override = default;
std::unique_ptr<Permute> New(size_t max_num_dims) override { std::unique_ptr<Permute> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) { if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<Permute>(new PermuteImpl()); return std::unique_ptr<Permute>(new PermuteImpl());
} else { } else {
return nullptr; return nullptr;
} }
} }
}; };
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl); REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl);
} // namespace } // namespace
} // namespace internal } // namespace internal
} // namespace permute } // namespace permute
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h" #include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h" #include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h" #include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Selects which forward kernel family SoftmaxGpu dispatches to.
enum class Algorithm {
  kSoftmax,
  kLogSoftmax,
};
template<Algorithm algorithm, typename T> template<Algorithm algorithm, typename T>
void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) { void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type; using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols); oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols); oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
if (algorithm == Algorithm::kSoftmax) { if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>( OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols))); cuda_stream, load, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) { } else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>( OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols))); cuda_stream, load, store, rows, cols)));
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
} }
// Softmax/LogSoftmax primitive implementation: extracts the HIP stream from
// the oneflow Stream and forwards to SoftmaxGpu with typed pointers.
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
  SoftmaxImpl() = default;
  ~SoftmaxImpl() override = default;
  // x: input of shape (rows, cols); y: output of the same shape. Both must
  // point to buffers of element type T.
  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
                             reinterpret_cast<T*>(y));
  }
};
template<typename SoftmaxBase, Algorithm algorithm, typename T> template<typename SoftmaxBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBase> NewSoftmax() { std::unique_ptr<SoftmaxBase> NewSoftmax() {
return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>()); return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>());
} }
template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm> template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>
class GenericSoftmaxFactoryImpl : public FactoryBase { class GenericSoftmaxFactoryImpl : public FactoryBase {
public: public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);
GenericSoftmaxFactoryImpl() = default; GenericSoftmaxFactoryImpl() = default;
~GenericSoftmaxFactoryImpl() override = default; ~GenericSoftmaxFactoryImpl() override = default;
std::unique_ptr<SoftmaxBase> New(DataType data_type) override { std::unique_ptr<SoftmaxBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \ #define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>}, {type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>> static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>
new_softmax_handle{ new_softmax_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)}; OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY #undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_handle.find(data_type); const auto it = new_softmax_handle.find(data_type);
if (it != new_softmax_handle.end()) { if (it != new_softmax_handle.end()) {
return it->second(); return it->second();
} else { } else {
return nullptr; return nullptr;
} }
} }
}; };
// Instantiate and register the forward softmax / log-softmax factories for
// DeviceType::kCUDA (which names the HIP backend in this ROCm port).
using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, Algorithm::kSoftmax>;
using LogSoftmaxFactoryImpl =
    GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h" #include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h" #include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h" #include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace { namespace {
// Selects which backward kernel family SoftmaxBackwardGpu dispatches to.
enum class Algorithm {
  kSoftmax,
  kLogSoftmax,
};
template<Algorithm algorithm, typename T> template<Algorithm algorithm, typename T>
void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy, void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,
T* dx) { T* dx) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type; using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols); cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols); cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
cuda::softmax::DirectStore<ComputeType, T> store(dx, cols); cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);
if (algorithm == Algorithm::kSoftmax) { if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>( decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols))); cuda_stream, load_y, load_dy, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) { } else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy), OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>( decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols))); cuda_stream, load_y, load_dy, store, rows, cols)));
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
} }
// SoftmaxBackward/LogSoftmaxBackward primitive implementation: extracts the
// HIP stream from the oneflow Stream and forwards to SoftmaxBackwardGpu.
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
class SoftmaxBackwardImpl : public SoftmaxBackwardBase {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);
  SoftmaxBackwardImpl() = default;
  ~SoftmaxBackwardImpl() override = default;
  // y: forward output, dy: upstream gradient, dx: resulting input gradient.
  // All are (rows, cols) buffers of element type T.
  void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,
              void* dx) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    SoftmaxBackwardGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(y),
                                     reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));
  }
};
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T> template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() { std::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() {
return std::unique_ptr<SoftmaxBackwardBase>( return std::unique_ptr<SoftmaxBackwardBase>(
new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>()); new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>());
} }
template<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm> template<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm>
class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase { class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {
public: public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl); OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);
GenericSoftmaxBackwardFactoryImpl() = default; GenericSoftmaxBackwardFactoryImpl() = default;
~GenericSoftmaxBackwardFactoryImpl() override = default; ~GenericSoftmaxBackwardFactoryImpl() override = default;
std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override { std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \ #define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>}, {type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>> static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>
new_softmax_backward_handle{ new_softmax_backward_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)}; OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY #undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_backward_handle.find(data_type); const auto it = new_softmax_backward_handle.find(data_type);
if (it != new_softmax_backward_handle.end()) { if (it != new_softmax_backward_handle.end()) {
return it->second(); return it->second();
} else { } else {
return nullptr; return nullptr;
} }
} }
}; };
// Instantiate and register the backward softmax / log-softmax factories for
// DeviceType::kCUDA (which names the HIP backend in this ROCm port).
using SoftmaxBackwardFactoryImpl =
    GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward, Algorithm::kSoftmax>;
using LogSoftmaxBackwardFactoryImpl =
    GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,
                                      Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory,
                           LogSoftmaxBackwardFactoryImpl);
} // namespace } // namespace
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ #ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ #define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM #ifdef WITH_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h> // #include <cuda_bf16.h>
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool) #define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) #define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) #define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) #define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) #define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) #define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) #define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) #define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) #define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) #define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) // #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #else // #else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// #endif // CUDA_VERSION >= 11000 // #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \ #define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \ CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \ CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \ #define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define UTIL_OPS_DATA_TYPE_SEQ \ #define UTIL_OPS_DATA_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \ CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \ CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \ CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \ CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \ CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \ CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#endif // WITH_ROCM #endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_ #endif // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
\ No newline at end of file
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/common/primitive/unary_functor.h" #include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h" #include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h" #include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
template<typename Dst, typename Src> template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> { struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {} UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Src>(0.5) * src return static_cast<Src>(0.5) * src
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src)); * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src));
} }
}; };
// tanh specializations: use the precision-matched libm entry point for float
// and double; half routes through float (no native half tanh is used here).
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, float, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, double, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }
};
// isinf predicates (Dst is bool); half is widened to float before the check,
// which preserves infinity exactly.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); }
};
// isnan predicates (Dst is bool); half is widened to float before the check,
// which preserves NaN-ness.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }
};
// Pseudo-half specializations: no native half implementation, so the half
// input is widened to float, the float functor runs, and the result is
// narrowed back to half. attr0/attr1 are forwarded to the float functor.
#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op)                      \
  template<>                                                              \
  struct UnaryFunctor<DeviceType::kCUDA, op, half, half> {                \
    UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
                                                                          \
    UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;      \
    OF_DEVICE_FUNC half operator()(half src) const {                      \
      return __float2half(float_functor(__half2float(src)));              \
    }                                                                     \
  };
// Ops that fall back to the float path for half inputs.
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// /*********nv_bfloat16_kernel*******/ // /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \ // #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \
// template<> \ // template<> \
// struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \ // struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \ // UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \ // \
// UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \ // UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \ // OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src))); \ // return __float2bfloat16(float_functor(__bfloat162float(src))); \
// } \ // } \
// }; // };
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold); // SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// template<> // template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> { // struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {} // UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); } // OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// }; // };
// template<> // template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> { // struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {} // UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); } // OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// }; // };
// #endif // #endif
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow