Commit 8f7de847 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

dtk

parent f262efc9
Pipeline #248 failed with stages
in 0 seconds
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC half operator()(half src0, half src1) const {
return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
#if defined(__CUDA_ARCH__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#elif defined(__HIP_DEVICE_COMPILE__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif
}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Src>(0.5)
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
+ x * coef * exp(static_cast<Src>(-0.5) * x * x))
* dy;
}
Src coef;
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
Src tanh_val = tanh(x);
return static_cast<Dst>(dy * (static_cast<Src>(1.0) - tanh_val * tanh_val));
}
};
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// }
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \
template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, half, half> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src0, half src1) const { \
return __float2half(float_functor(__half2float(src0), __half2float(src1))); \
} \
};
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
};
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
}
};
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC half operator()(half src0, half src1) const {
return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
}
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
#if defined(__CUDA_ARCH__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#elif defined(__HIP_DEVICE_COMPILE__)
coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif
}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Src>(0.5)
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
+ x * coef * exp(static_cast<Src>(-0.5) * x * x))
* dy;
}
Src coef;
};
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
Src tanh_val = tanh(x);
return static_cast<Dst>(dy * (static_cast<Src>(1.0) - tanh_val * tanh_val));
}
};
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
// return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
// }
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
// template<> \
// struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif // CUDA_VERSION >= 11000
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op) \
template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, half, half> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src0, half src1) const { \
return __float2half(float_functor(__half2float(src0), __half2float(src1))); \
} \
};
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1);
namespace {
class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
BroadcastElementwiseBinaryFactoryImpl() = default;
~BroadcastElementwiseBinaryFactoryImpl() override = default;
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims) override {
return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
}
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims, Scalar attr0) override {
return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
}
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,
DataType dst_type, size_t max_num_dims,
Scalar attr0, Scalar attr1) override {
if (max_num_dims > kMaxNumDims) { return nullptr; }
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op, src_data_type_pair, dst_data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND(dst_data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \
OF_PP_PAIR_FIRST(dst_data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
static const std::map<
std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>
new_broadcast_elementwise_binary_handle{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
const auto it = new_broadcast_elementwise_binary_handle.find(
std::make_tuple(binary_op, src_type, dst_type));
if (it != new_broadcast_elementwise_binary_handle.end()) {
return it->second(attr0, attr1);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,
BroadcastElementwiseBinaryFactoryImpl);
} // namespace
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1);
namespace {
class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
BroadcastElementwiseBinaryFactoryImpl() = default;
~BroadcastElementwiseBinaryFactoryImpl() override = default;
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims) override {
return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
}
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,
size_t max_num_dims, Scalar attr0) override {
return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
}
std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,
DataType dst_type, size_t max_num_dims,
Scalar attr0, Scalar attr1) override {
if (max_num_dims > kMaxNumDims) { return nullptr; }
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \
binary_op, src_data_type_pair, dst_data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \
OF_PP_PAIR_SECOND(dst_data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \
OF_PP_PAIR_FIRST(dst_data_type_pair)>},
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \
OF_PP_PAIR_SECOND(data_type_pair)), \
NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(data_type_pair)>},
static const std::map<
std::tuple<BinaryOp, DataType, DataType>,
std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>
new_broadcast_elementwise_binary_handle{
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
const auto it = new_broadcast_elementwise_binary_handle.find(
std::make_tuple(binary_op, src_type, dst_type));
if (it != new_broadcast_elementwise_binary_handle.end()) {
return it->second(attr0, attr1);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,
BroadcastElementwiseBinaryFactoryImpl);
} // namespace
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
namespace {
template<typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template<typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
OF_DEVICE_FUNC Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
template<size_t max_dims, typename IndexType>
struct BroadcastElementwiseBinaryParams {
NdIndexOffsetHelper<IndexType, max_dims> src0_index_helper;
NdIndexOffsetHelper<IndexType, max_dims> src1_index_helper;
NdIndexOffsetHelper<IndexType, max_dims> dst_index_helper;
size_t num_dims;
IndexType src0_index_mask[max_dims];
IndexType src1_index_mask[max_dims];
IndexType count{};
const void* src0{};
const void* src1{};
void* dst{};
Scalar attr0;
Scalar attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size, typename IndexType>
__global__ void BroadcastElementwiseBinaryGpu(
BroadcastElementwiseBinaryParams<max_dims, IndexType> params) {
constexpr size_t dst_pack_size =
src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size;
static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, "");
static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, "");
const PackType<Src, src0_pack_size>* src0 =
reinterpret_cast<const PackType<Src, src0_pack_size>*>(params.src0);
const PackType<Src, src1_pack_size>* src1 =
reinterpret_cast<const PackType<Src, src1_pack_size>*>(params.src1);
PackType<Dst, dst_pack_size>* dst = reinterpret_cast<PackType<Dst, dst_pack_size>*>(params.dst);
IndexType src0_index[max_dims];
IndexType src1_index[max_dims];
IndexType dst_index[max_dims];
size_t num_dims = params.num_dims;
CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {
params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
#pragma unroll
for (int i = 0; i < max_dims; ++i) {
if (i < num_dims) {
src0_index[i] = params.src0_index_mask[i] * dst_index[i];
src1_index[i] = params.src1_index_mask[i] * dst_index[i];
} else {
src0_index[i] = 0;
src1_index[i] = 0;
}
}
const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims);
const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims);
Pack<Src, src0_pack_size> src0_pack;
src0_pack.storage = src0[src0_offset];
Pack<Src, src1_pack_size> src1_pack;
src1_pack.storage = src1[src1_offset];
Pack<Dst, dst_pack_size> dst_pack;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor(params.attr0, params.attr1);
#pragma unroll
for (int j = 0; j < dst_pack_size; ++j) {
const Src src0_val =
(src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0];
const Src src1_val =
(src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0];
dst_pack.elem[j] = functor(src0_val, src1_val);
}
dst[offset] = dst_pack.storage;
}
}
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size, typename IndexType>
void LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst,
size_t count, Scalar attr0, Scalar attr1) {
BroadcastElementwiseBinaryParams<max_dims, IndexType> params;
for (size_t i = 0; i < num_dims; ++i) {
params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1;
params.src1_index_mask[i] = (src1_dims[i] == 1) ? 0 : 1;
}
params.src0_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src0_dims, num_dims);
params.src1_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src1_dims, num_dims);
params.dst_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(dst_dims, num_dims);
params.num_dims = num_dims;
params.src0 = src0;
params.src1 = src1;
params.dst = dst;
params.count = static_cast<IndexType>(count);
params.attr0 = attr0;
params.attr1 = attr1;
auto* cuda_stream = stream->As<CudaStream>();
BroadcastElementwiseBinaryGpu<op, T, R, max_dims, src0_pack_size, src1_pack_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0,
cuda_stream->cuda_stream()>>>(params);
}
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size>
void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
void* dst, Scalar attr0, Scalar attr1) {
size_t count = GetElementCount(num_dims, dst_dims);
if (count < GetMaxVal<int32_t>()) {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
} else {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
}
}
template<BinaryOp op, typename T, typename R, size_t max_dims>
void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/,
const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/,
const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) =
nullptr;
if (src0_pack_size == 1 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 1, 1>;
} else if (src0_pack_size == 4 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 4, 4>;
} else if (src0_pack_size == 1 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 1, 4>;
} else if (src0_pack_size == 4 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 4, 1>;
} else {
UNIMPLEMENTED();
}
func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1);
}
template<BinaryOp op, typename T, typename R>
void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/,
size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/,
const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/,
void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
CHECK_NE(num_dims, 1);
if (num_dims == 2) {
func = DispatchPackSize<op, T, R, 2>;
} else if (num_dims == 3) {
func = DispatchPackSize<op, T, R, 3>;
} else if (num_dims == 4) {
func = DispatchPackSize<op, T, R, 4>;
} else if (num_dims <= 8) {
func = DispatchPackSize<op, T, R, 8>;
} else {
UNIMPLEMENTED();
}
func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims,
dst, attr0, attr1);
}
template<size_t max_pack_size, typename T, typename R>
size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, void* dst) {
static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, "");
CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1);
auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {
bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1)
|| IsPackSizeSupported<T>(pack_size, num_src_dims, src0_dims, src0);
bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1)
|| IsPackSizeSupported<T>(pack_size, num_src_dims, src1_dims, src1);
if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) {
return pack_size;
}
}
return 1;
}
constexpr size_t kMaxPackSize = 4;
template<BinaryOp op, typename T, typename R>
void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims,
const void* src0, int64_t* simplified_src1_dims, const void* src1,
int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) {
CHECK_LE(simplified_num_dims, kMaxNumDims);
size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims,
src0, simplified_src1_dims, src1, dst);
size_t src0_pack_size = 1;
size_t src1_pack_size = 1;
if (simplified_src0_dims[simplified_num_dims - 1] != 1) {
simplified_src0_dims[simplified_num_dims - 1] /= pack_size;
src0_pack_size = pack_size;
}
if (simplified_src1_dims[simplified_num_dims - 1] != 1) {
simplified_src1_dims[simplified_num_dims - 1] /= pack_size;
src1_pack_size = pack_size;
}
simplified_dst_dims[simplified_num_dims - 1] /= pack_size;
DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims,
simplified_src0_dims, src0, simplified_src1_dims, src1,
simplified_dst_dims, dst, attr0, attr1);
}
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarFunctor {
__host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
: scalar(scalar), functor(attr0, attr1) {}
__device__ Dst operator()(Src src) const { return functor(scalar, src); }
const Src scalar;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarFunctor {
__host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
: scalar(scalar), functor(attr0, attr1) {}
__device__ Dst operator()(Src src) const { return functor(src, scalar); }
const Src scalar;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarPtrFunctorFactory {
__host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
Scalar attr1)
: scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
__device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const {
return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
}
const Src* scalar_ptr;
Scalar attr0, attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarPtrFunctorFactory {
__host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
Scalar attr1)
: scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
__device__ BinaryRhsScalarFunctor<binary_op, Src, Dst> operator()() const {
return BinaryRhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
}
const Src* scalar_ptr;
Scalar attr0, attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst>
void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0,
size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst,
Scalar attr0, Scalar attr1) {
auto* cuda_stream = stream->As<CudaStream>();
size_t simplified_num_dims = 0;
int64_t simplified_src0_dims[kMaxNumDims];
int64_t simplified_src1_dims[kMaxNumDims];
int64_t simplified_dst_dims[kMaxNumDims];
SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
&simplified_num_dims, simplified_src0_dims,
simplified_src1_dims, simplified_dst_dims);
CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,
simplified_dst_dims, dst);
if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
simplified_src1_dims)) {
const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);
OF_CUDA_CHECK((cuda::elementwise::Binary(
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst>(attr0, attr1), elem_cnt, dst, src0,
src1, cuda_stream->cuda_stream())));
} else {
if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {
OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
BinaryLhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src0, attr0, attr1),
simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream())));
} else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {
OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
BinaryRhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src1, attr0, attr1),
simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream())));
} else {
LaunchWithSimplified<binary_op, Src, Dst>(stream, simplified_num_dims, simplified_src0_dims,
src0, simplified_src1_dims, src1,
simplified_dst_dims, dst, attr0, attr1);
}
}
}
template<typename T>
T GetValue(Scalar value) {
return value.Value<T>();
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<BinaryOp binary_op, typename Src, typename Dst>
class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);
BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~BroadcastElementwiseBinaryImpl() override = default;
void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
const void* src1, void* dst) override {
auto* cuda_stream = stream->As<CudaStream>();
const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);
OF_CUDA_CHECK((cuda::elementwise::Unary(
BinaryLhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src0), attr0, attr1), elem_cnt,
reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src1),
cuda_stream->cuda_stream())));
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
Scalar src1, void* dst) override {
auto* cuda_stream = stream->As<CudaStream>();
const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);
OF_CUDA_CHECK((cuda::elementwise::Unary(
BinaryRhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src1), attr0, attr1), elem_cnt,
reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src0),
cuda_stream->cuda_stream())));
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
void* dst) override {
DispatchLaunch<binary_op, Src, Dst>(
stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,
src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);
}
private:
Scalar attr0, attr1;
};
} // namespace
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1) {
return std::unique_ptr<BroadcastElementwiseBinary>(
new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));
}
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
namespace {
template<typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template<typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
OF_DEVICE_FUNC Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
template<size_t max_dims, typename IndexType>
struct BroadcastElementwiseBinaryParams {
NdIndexOffsetHelper<IndexType, max_dims> src0_index_helper;
NdIndexOffsetHelper<IndexType, max_dims> src1_index_helper;
NdIndexOffsetHelper<IndexType, max_dims> dst_index_helper;
size_t num_dims;
IndexType src0_index_mask[max_dims];
IndexType src1_index_mask[max_dims];
IndexType count{};
const void* src0{};
const void* src1{};
void* dst{};
Scalar attr0;
Scalar attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size, typename IndexType>
__global__ void BroadcastElementwiseBinaryGpu(
BroadcastElementwiseBinaryParams<max_dims, IndexType> params) {
constexpr size_t dst_pack_size =
src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size;
static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, "");
static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, "");
const PackType<Src, src0_pack_size>* src0 =
reinterpret_cast<const PackType<Src, src0_pack_size>*>(params.src0);
const PackType<Src, src1_pack_size>* src1 =
reinterpret_cast<const PackType<Src, src1_pack_size>*>(params.src1);
PackType<Dst, dst_pack_size>* dst = reinterpret_cast<PackType<Dst, dst_pack_size>*>(params.dst);
IndexType src0_index[max_dims];
IndexType src1_index[max_dims];
IndexType dst_index[max_dims];
size_t num_dims = params.num_dims;
CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {
params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);
#pragma unroll
for (int i = 0; i < max_dims; ++i) {
if (i < num_dims) {
src0_index[i] = params.src0_index_mask[i] * dst_index[i];
src1_index[i] = params.src1_index_mask[i] * dst_index[i];
} else {
src0_index[i] = 0;
src1_index[i] = 0;
}
}
const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims);
const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims);
Pack<Src, src0_pack_size> src0_pack;
src0_pack.storage = src0[src0_offset];
Pack<Src, src1_pack_size> src1_pack;
src1_pack.storage = src1[src1_offset];
Pack<Dst, dst_pack_size> dst_pack;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor(params.attr0, params.attr1);
#pragma unroll
for (int j = 0; j < dst_pack_size; ++j) {
const Src src0_val =
(src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0];
const Src src1_val =
(src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0];
dst_pack.elem[j] = functor(src0_val, src1_val);
}
dst[offset] = dst_pack.storage;
}
}
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size, typename IndexType>
void LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst,
size_t count, Scalar attr0, Scalar attr1) {
BroadcastElementwiseBinaryParams<max_dims, IndexType> params;
for (size_t i = 0; i < num_dims; ++i) {
params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1;
params.src1_index_mask[i] = (src1_dims[i] == 1) ? 0 : 1;
}
params.src0_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src0_dims, num_dims);
params.src1_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src1_dims, num_dims);
params.dst_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(dst_dims, num_dims);
params.num_dims = num_dims;
params.src0 = src0;
params.src1 = src1;
params.dst = dst;
params.count = static_cast<IndexType>(count);
params.attr0 = attr0;
params.attr1 = attr1;
auto* cuda_stream = stream->As<CudaStream>();
BroadcastElementwiseBinaryGpu<op, T, R, max_dims, src0_pack_size, src1_pack_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0,
cuda_stream->cuda_stream()>>>(params);
}
template<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,
size_t src1_pack_size>
void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,
void* dst, Scalar attr0, Scalar attr1) {
size_t count = GetElementCount(num_dims, dst_dims);
if (count < GetMaxVal<int32_t>()) {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
} else {
LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>(
stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);
}
}
template<BinaryOp op, typename T, typename R, size_t max_dims>
void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/,
const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/,
const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) =
nullptr;
if (src0_pack_size == 1 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 1, 1>;
} else if (src0_pack_size == 4 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 4, 4>;
} else if (src0_pack_size == 1 && src1_pack_size == 4) {
func = DispatchIndexType<op, T, R, max_dims, 1, 4>;
} else if (src0_pack_size == 4 && src1_pack_size == 1) {
func = DispatchIndexType<op, T, R, max_dims, 4, 1>;
} else {
UNIMPLEMENTED();
}
func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1);
}
template<BinaryOp op, typename T, typename R>
void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,
const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,
const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,
Scalar attr1) {
void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/,
size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/,
const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/,
void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
CHECK_NE(num_dims, 1);
if (num_dims == 2) {
func = DispatchPackSize<op, T, R, 2>;
} else if (num_dims == 3) {
func = DispatchPackSize<op, T, R, 3>;
} else if (num_dims == 4) {
func = DispatchPackSize<op, T, R, 4>;
} else if (num_dims <= 8) {
func = DispatchPackSize<op, T, R, 8>;
} else {
UNIMPLEMENTED();
}
func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims,
dst, attr0, attr1);
}
template<size_t max_pack_size, typename T, typename R>
size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0,
const int64_t* src1_dims, const void* src1, void* dst) {
static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, "");
CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1);
auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {
bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1)
|| IsPackSizeSupported<T>(pack_size, num_src_dims, src0_dims, src0);
bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1)
|| IsPackSizeSupported<T>(pack_size, num_src_dims, src1_dims, src1);
if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) {
return pack_size;
}
}
return 1;
}
constexpr size_t kMaxPackSize = 4;
template<BinaryOp op, typename T, typename R>
void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims,
const void* src0, int64_t* simplified_src1_dims, const void* src1,
int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) {
CHECK_LE(simplified_num_dims, kMaxNumDims);
size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims,
src0, simplified_src1_dims, src1, dst);
size_t src0_pack_size = 1;
size_t src1_pack_size = 1;
if (simplified_src0_dims[simplified_num_dims - 1] != 1) {
simplified_src0_dims[simplified_num_dims - 1] /= pack_size;
src0_pack_size = pack_size;
}
if (simplified_src1_dims[simplified_num_dims - 1] != 1) {
simplified_src1_dims[simplified_num_dims - 1] /= pack_size;
src1_pack_size = pack_size;
}
simplified_dst_dims[simplified_num_dims - 1] /= pack_size;
DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims,
simplified_src0_dims, src0, simplified_src1_dims, src1,
simplified_dst_dims, dst, attr0, attr1);
}
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarFunctor {
__host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
: scalar(scalar), functor(attr0, attr1) {}
__device__ Dst operator()(Src src) const { return functor(scalar, src); }
const Src scalar;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarFunctor {
__host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)
: scalar(scalar), functor(attr0, attr1) {}
__device__ Dst operator()(Src src) const { return functor(src, scalar); }
const Src scalar;
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryLhsScalarPtrFunctorFactory {
__host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
Scalar attr1)
: scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
__device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const {
return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
}
const Src* scalar_ptr;
Scalar attr0, attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst>
struct BinaryRhsScalarPtrFunctorFactory {
__host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,
Scalar attr1)
: scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}
__device__ BinaryRhsScalarFunctor<binary_op, Src, Dst> operator()() const {
return BinaryRhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);
}
const Src* scalar_ptr;
Scalar attr0, attr1;
};
template<BinaryOp binary_op, typename Src, typename Dst>
void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0,
size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst,
Scalar attr0, Scalar attr1) {
auto* cuda_stream = stream->As<CudaStream>();
size_t simplified_num_dims = 0;
int64_t simplified_src0_dims[kMaxNumDims];
int64_t simplified_src1_dims[kMaxNumDims];
int64_t simplified_dst_dims[kMaxNumDims];
SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
&simplified_num_dims, simplified_src0_dims,
simplified_src1_dims, simplified_dst_dims);
CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,
simplified_dst_dims, dst);
if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
simplified_src1_dims)) {
const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);
OF_CUDA_CHECK((cuda::elementwise::Binary(
BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst>(attr0, attr1), elem_cnt, dst, src0,
src1, cuda_stream->cuda_stream())));
} else {
if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {
OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
BinaryLhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src0, attr0, attr1),
simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream())));
} else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {
OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(
BinaryRhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src1, attr0, attr1),
simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream())));
} else {
LaunchWithSimplified<binary_op, Src, Dst>(stream, simplified_num_dims, simplified_src0_dims,
src0, simplified_src1_dims, src1,
simplified_dst_dims, dst, attr0, attr1);
}
}
}
template<typename T>
T GetValue(Scalar value) {
return value.Value<T>();
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<BinaryOp binary_op, typename Src, typename Dst>
class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);
BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~BroadcastElementwiseBinaryImpl() override = default;
void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,
const void* src1, void* dst) override {
auto* cuda_stream = stream->As<CudaStream>();
const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);
OF_CUDA_CHECK((cuda::elementwise::Unary(
BinaryLhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src0), attr0, attr1), elem_cnt,
reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src1),
cuda_stream->cuda_stream())));
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
Scalar src1, void* dst) override {
auto* cuda_stream = stream->As<CudaStream>();
const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);
OF_CUDA_CHECK((cuda::elementwise::Unary(
BinaryRhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src1), attr0, attr1), elem_cnt,
reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src0),
cuda_stream->cuda_stream())));
}
void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,
size_t num_src1_dims, const int64_t* src1_dims, const void* src1,
void* dst) override {
DispatchLaunch<binary_op, Src, Dst>(
stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,
src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);
}
private:
Scalar attr0, attr1;
};
} // namespace
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
Scalar attr1) {
return std::unique_ptr<BroadcastElementwiseBinary>(
new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));
}
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \
data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, \
data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \
binary_op, src_data_type_pair, dst_data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \
binary_op, src_data_type_pair, dst_data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,
BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
dst_data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,
CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \
dst_data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,
BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,
CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_matmul {
namespace internal {
namespace {
constexpr size_t kMaxNumDims = 8;
Optional<hipblasDatatype_t> OptCudaDataType(DataType data_type) {
switch (data_type) {
case kFloat: return HIPBLAS_R_32F;
case kDouble: return HIPBLAS_R_64F;
case kFloat16: return HIPBLAS_R_16F;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return CUDA_R_16BF;
// #endif // CUDA_VERSION >= 11000
default: return NullOpt;
}
}
hipblasDatatype_t GetCudaDataType(DataType data_type) {
auto cuda_data_type = OptCudaDataType(data_type);
CHECK(cuda_data_type.has_value());
return cuda_data_type.value_or(HIPBLAS_R_32F);
}
union CublasScalarParameter {
double d;
float s;
};
CublasScalarParameter GetCublasScalarParameter(Scalar scalar, hipblasDatatype_t compute_type) {
CublasScalarParameter sp{};
if (compute_type == HIPBLAS_R_64F) {
sp.d = scalar.Value<double>();
} else if (compute_type == HIPBLAS_R_32F) {
sp.s = scalar.Value<float>();
} else if (compute_type == HIPBLAS_R_16F) {
sp.s = scalar.Value<float>();
} else {
UNIMPLEMENTED();
}
return sp;
}
hipblasDatatype_t GetComputeType(DataType data_type) {
switch (data_type) {
case kFloat: return HIPBLAS_R_32F;
case kDouble: return HIPBLAS_R_64F;
case kFloat16: return HIPBLAS_R_16F;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return HIPBLAS_R_32F;
// #endif // CUDA_VERSION >= 11000
default: UNIMPLEMENTED(); return HIPBLAS_R_32F;
}
}
void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b, int64_t num_batch_dims,
const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,
const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,
int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,
Scalar beta, void* c) {
auto* cuda_stream = stream->As<CudaStream>();
const auto cuda_data_type = GetCudaDataType(data_type);
const auto compute_type = GetComputeType(data_type);
const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
__half h_alpha = 0;
if (compute_type == HIPBLAS_R_16F) {
h_alpha = __float2half(sp_alpha.s);
}
const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
if (transpose_type == BlasTransposeType::N) {
return HIPBLAS_OP_N;
} else if (transpose_type == BlasTransposeType::T) {
return HIPBLAS_OP_T;
} else {
UNIMPLEMENTED();
return HIPBLAS_OP_N;
}
};
const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b);
const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a);
const int cublas_m = n;
const int cublas_n = m;
const int cublas_k = k;
int cublas_lda = 0;
if (transpose_b == BlasTransposeType::N) {
cublas_lda = n;
} else if (transpose_b == BlasTransposeType::T) {
cublas_lda = k;
} else {
UNIMPLEMENTED();
}
int cublas_ldb = 0;
if (transpose_a == BlasTransposeType::N) {
cublas_ldb = k;
} else if (transpose_a == BlasTransposeType::T) {
cublas_ldb = m;
} else {
UNIMPLEMENTED();
}
const int cublas_ldc = n;
// CublasMathModeGuard guard(cuda_stream->cublas_handle());
// if (data_type == DataType::kFloat16) {
// #if CUDA_VERSION < 11000
// guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
// #else
// guard.SetMathMode(CUBLAS_DEFAULT_MATH);
// #endif // CUDA_VERSION < 11000
// }
// #if CUDA_VERSION >= 11000
// hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
// #else
// hipblasGemmAlgo_t algo =
// (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT;
// #endif
if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
const void* cublas_a = b;
const void* cublas_b = a;
void* cublas_c = c;
const int64_t a_batch_count = a_batch_dims[0];
const int64_t b_batch_count = b_batch_dims[0];
CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);
CHECK_GT(a_batch_count, 0);
CHECK_GT(b_batch_count, 0);
const int batch_count = std::max(a_batch_count, b_batch_count);
const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;
const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;
const long long int cublas_stride_c = cublas_m * cublas_n;
const auto sp_beta = GetCublasScalarParameter(beta, compute_type);
__half h_beta = 0;
if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &h_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo));
} else {
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo));
}
} else {
auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {
const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);
__half h_beta = 0;
const void* cublas_a = batch_b;
const void* cublas_b = batch_a;
void* cublas_c = batch_c;
if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &h_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
} else {
OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
}
};
ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,
a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);
}
}
class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);
BroadcastMatmulFactoryImpl() = default;
~BroadcastMatmulFactoryImpl() override = default;
std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b,
size_t max_num_dims) override {
auto cuda_data_type = OptCudaDataType(data_type);
if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) {
return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,
transpose_b);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);
} // namespace
} // namespace internal
} // namespace broadcast_matmul
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_matmul {
namespace internal {
namespace {
constexpr size_t kMaxNumDims = 8;
Optional<hipblasDatatype_t> OptCudaDataType(DataType data_type) {
switch (data_type) {
case kFloat: return HIPBLAS_R_32F;
case kDouble: return HIPBLAS_R_64F;
case kFloat16: return HIPBLAS_R_16F;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return CUDA_R_16BF;
// #endif // CUDA_VERSION >= 11000
default: return NullOpt;
}
}
hipblasDatatype_t GetCudaDataType(DataType data_type) {
auto cuda_data_type = OptCudaDataType(data_type);
CHECK(cuda_data_type.has_value());
return cuda_data_type.value_or(HIPBLAS_R_32F);
}
union CublasScalarParameter {
double d;
float s;
};
CublasScalarParameter GetCublasScalarParameter(Scalar scalar, hipblasDatatype_t compute_type) {
CublasScalarParameter sp{};
if (compute_type == HIPBLAS_R_64F) {
sp.d = scalar.Value<double>();
} else if (compute_type == HIPBLAS_R_32F) {
sp.s = scalar.Value<float>();
} else if (compute_type == HIPBLAS_R_16F) {
sp.s = scalar.Value<float>();
} else {
UNIMPLEMENTED();
}
return sp;
}
hipblasDatatype_t GetComputeType(DataType data_type) {
switch (data_type) {
case kFloat: return HIPBLAS_R_32F;
case kDouble: return HIPBLAS_R_64F;
case kFloat16: return HIPBLAS_R_16F;
// #if CUDA_VERSION >= 11000
// case kBFloat16: return HIPBLAS_R_32F;
// #endif // CUDA_VERSION >= 11000
default: UNIMPLEMENTED(); return HIPBLAS_R_32F;
}
}
void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b, int64_t num_batch_dims,
const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,
const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,
int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,
Scalar beta, void* c) {
auto* cuda_stream = stream->As<CudaStream>();
const auto cuda_data_type = GetCudaDataType(data_type);
const auto compute_type = GetComputeType(data_type);
const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
__half h_alpha = 0;
if (compute_type == HIPBLAS_R_16F) {
h_alpha = __float2half(sp_alpha.s);
}
const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
if (transpose_type == BlasTransposeType::N) {
return HIPBLAS_OP_N;
} else if (transpose_type == BlasTransposeType::T) {
return HIPBLAS_OP_T;
} else {
UNIMPLEMENTED();
return HIPBLAS_OP_N;
}
};
const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b);
const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a);
const int cublas_m = n;
const int cublas_n = m;
const int cublas_k = k;
int cublas_lda = 0;
if (transpose_b == BlasTransposeType::N) {
cublas_lda = n;
} else if (transpose_b == BlasTransposeType::T) {
cublas_lda = k;
} else {
UNIMPLEMENTED();
}
int cublas_ldb = 0;
if (transpose_a == BlasTransposeType::N) {
cublas_ldb = k;
} else if (transpose_a == BlasTransposeType::T) {
cublas_ldb = m;
} else {
UNIMPLEMENTED();
}
const int cublas_ldc = n;
// CublasMathModeGuard guard(cuda_stream->cublas_handle());
// if (data_type == DataType::kFloat16) {
// #if CUDA_VERSION < 11000
// guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
// #else
// guard.SetMathMode(CUBLAS_DEFAULT_MATH);
// #endif // CUDA_VERSION < 11000
// }
// #if CUDA_VERSION >= 11000
// hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
// #else
// hipblasGemmAlgo_t algo =
// (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : HIPBLAS_GEMM_DEFAULT;
// #endif
if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
const void* cublas_a = b;
const void* cublas_b = a;
void* cublas_c = c;
const int64_t a_batch_count = a_batch_dims[0];
const int64_t b_batch_count = b_batch_dims[0];
CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);
CHECK_GT(a_batch_count, 0);
CHECK_GT(b_batch_count, 0);
const int batch_count = std::max(a_batch_count, b_batch_count);
const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;
const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;
const long long int cublas_stride_c = cublas_m * cublas_n;
const auto sp_beta = GetCublasScalarParameter(beta, compute_type);
__half h_beta = 0;
if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &h_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo));
} else {
OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,
&sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,
cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc,
cublas_stride_c, batch_count, compute_type, algo));
}
} else {
auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {
const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);
__half h_beta = 0;
const void* cublas_a = batch_b;
const void* cublas_b = batch_a;
void* cublas_c = batch_c;
if (compute_type == HIPBLAS_R_16F) {
h_beta = __float2half(sp_beta.s);
OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &h_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &h_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
} else {
OF_CUBLAS_CHECK(hipblasGemmEx(
cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));
}
};
ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,
a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);
}
}
class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);
BroadcastMatmulFactoryImpl() = default;
~BroadcastMatmulFactoryImpl() override = default;
std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,
BlasTransposeType transpose_b,
size_t max_num_dims) override {
auto cuda_data_type = OptCudaDataType(data_type);
if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) {
return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,
transpose_b);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);
} // namespace
} // namespace internal
} // namespace broadcast_matmul
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<typename To, typename From, typename = void>
struct CastFunctor {
__device__ To operator()(From from) const { return static_cast<To>(from); }
};
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
__device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }
__device__ void Apply2(To* to, const half* from) const {
const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));
to[0] = static_cast<To>(f2.x);
to[1] = static_cast<To>(f2.y);
}
};
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
__device__ half operator()(From from) const {
return static_cast<half>(static_cast<float>(from));
}
__device__ void Apply2(half* to, const From* from) const {
float2 f2;
f2.x = static_cast<float>(from[0]);
f2.y = static_cast<float>(from[1]);
*reinterpret_cast<half2*>(to) = __float22half2_rn(f2);
}
};
// #if CUDA_VERSION >= 11000
// template<typename To>
// struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from));
// }
// };
// template<typename From>
// struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// }
// };
// #endif // CUDA_VERSION >= 11000
template<typename From, typename To>
class CastImpl : public Cast {
public:
OF_DISALLOW_COPY_AND_MOVE(CastImpl);
explicit CastImpl() = default;
~CastImpl() override = default;
void Launch(Stream* stream, const void* from, void* to, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK((cuda::elementwise::Unary<CastFunctor<To, From>, To, From>(
CastFunctor<To, From>(), count, reinterpret_cast<To*>(to),
reinterpret_cast<const From*>(from), cuda_stream->cuda_stream())));
}
};
template<typename From, typename To>
std::unique_ptr<Cast> NewCast() {
return std::unique_ptr<Cast>(new CastImpl<From, To>());
}
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_UINT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_UINT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
class CastFactoryImpl : public CastFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);
CastFactoryImpl() = default;
~CastFactoryImpl() override = default;
std::unique_ptr<Cast> New(DataType from, DataType to) override {
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>
new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_CAST_TYPE_SEQ, CUDA_PRIMITIVE_CAST_TYPE_SEQ)};
#undef MAKE_NEW_CAST_ENTRY
const auto it = new_cast_handle.find(std::make_pair(from, to));
if (it != new_cast_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/cast.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<typename To, typename From, typename = void>
struct CastFunctor {
__device__ To operator()(From from) const { return static_cast<To>(from); }
};
template<typename To>
struct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {
__device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }
__device__ void Apply2(To* to, const half* from) const {
const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));
to[0] = static_cast<To>(f2.x);
to[1] = static_cast<To>(f2.y);
}
};
template<typename From>
struct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {
__device__ half operator()(From from) const {
return static_cast<half>(static_cast<float>(from));
}
__device__ void Apply2(half* to, const From* from) const {
float2 f2;
f2.x = static_cast<float>(from[0]);
f2.y = static_cast<float>(from[1]);
*reinterpret_cast<half2*>(to) = __float22half2_rn(f2);
}
};
// #if CUDA_VERSION >= 11000
// template<typename To>
// struct CastFunctor<To, nv_bfloat16,
// typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value
// || std::is_same<To, half>::value)>::type> {
// __device__ To operator()(nv_bfloat16 from) const {
// return static_cast<To>(static_cast<float>(from));
// }
// };
// template<typename From>
// struct CastFunctor<nv_bfloat16, From,
// typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value
// || std::is_same<From, half>::value)>::type> {
// __device__ nv_bfloat16 operator()(From from) const {
// return static_cast<nv_bfloat16>(static_cast<float>(from));
// }
// };
// #endif // CUDA_VERSION >= 11000
template<typename From, typename To>
class CastImpl : public Cast {
public:
OF_DISALLOW_COPY_AND_MOVE(CastImpl);
explicit CastImpl() = default;
~CastImpl() override = default;
void Launch(Stream* stream, const void* from, void* to, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK((cuda::elementwise::Unary<CastFunctor<To, From>, To, From>(
CastFunctor<To, From>(), count, reinterpret_cast<To*>(to),
reinterpret_cast<const From*>(from), cuda_stream->cuda_stream())));
}
};
template<typename From, typename To>
std::unique_ptr<Cast> NewCast() {
return std::unique_ptr<Cast>(new CastImpl<From, To>());
}
#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_UINT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_UINT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
class CastFactoryImpl : public CastFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);
CastFactoryImpl() = default;
~CastFactoryImpl() override = default;
std::unique_ptr<Cast> New(DataType from, DataType to) override {
#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair) \
{std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \
NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},
static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>
new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_CAST_TYPE_SEQ, CUDA_PRIMITIVE_CAST_TYPE_SEQ)};
#undef MAKE_NEW_CAST_ENTRY
const auto it = new_cast_handle.find(std::make_pair(from, to));
if (it != new_cast_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t num_dims, typename IndexType, typename StorageType>
__global__ void ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params,
StorageType packed_pad_val) {
const StorageType* src = reinterpret_cast<const StorageType*>(params.src);
StorageType* dst = reinterpret_cast<StorageType*>(params.dst);
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) {
params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);
bool if_pad = false;
#pragma unroll
for (int i = 0; i < num_dims; i++) {
if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) {
src_index[i] = dst_index[i] - params.valid_start[i];
} else {
if_pad = true;
break;
}
}
StorageType dst_val = packed_pad_val;
if (!if_pad) {
const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
dst_val = src[src_offset];
}
dst[linear_index] = dst_val;
}
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, ConstantPadParams<num_dims, IndexType> params,
StorageType packed_pad_val, size_t elem_cnt) {
stream->As<CudaStream>()->LaunchKernelDefaultWaves(
(ConstantPadKernel<num_dims, IndexType, StorageType>), elem_cnt, params, packed_pad_val);
}
template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
ConstantPadParams<num_dims, IndexType> params;
params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims);
params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);
params.dst = dst;
params.src = src;
for (int i = 0; i < num_dims; i++) {
params.valid_start[i] = padding_before[i];
params.valid_end[i] = dst_dims[i] - padding_after[i];
}
params.elem_cnt = elem_cnt;
LaunchKernel<num_dims, IndexType, StorageType>(stream, params, packed_pad_val, elem_cnt);
}
template<size_t num_dims, typename StorageType>
void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
if (elem_cnt < GetMaxVal<int32_t>()) {
LaunchKernel<num_dims, int32_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val,
elem_cnt);
} else {
LaunchKernel<num_dims, int64_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val,
elem_cnt);
}
}
template<size_t num_dims, typename T>
void DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src,
int64_t* src_dims, int64_t* padding_before, int64_t* padding_after,
T pad_val) {
constexpr int32_t max_packsize = GetMaxPackSize<T>();
size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims,
padding_before, padding_after);
dst_dims[num_dims - 1] /= launch_pack_size;
src_dims[num_dims - 1] /= launch_pack_size;
padding_before[num_dims - 1] /= launch_pack_size;
padding_after[num_dims - 1] /= launch_pack_size;
size_t elem_cnt = 1;
for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; }
if (launch_pack_size == 1) {
Pack<T, 1> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 1>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 2) {
Pack<T, 2> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 2>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 4) {
Pack<T, 4> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 4>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 8) {
Pack<T, 8> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 8>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 16) {
Pack<T, 16> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 16>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else {
UNIMPLEMENTED();
}
}
template<typename T>
void LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims,
const void* src, int64_t* src_dims, int64_t* padding_before,
int64_t* padding_after, T pad_val) {
void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/,
int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) =
nullptr;
if (num_dims == 1) {
func = DispatchPackSize<1, T>;
} else if (num_dims == 2) {
func = DispatchPackSize<2, T>;
} else if (num_dims == 3) {
func = DispatchPackSize<3, T>;
} else if (num_dims == 4) {
func = DispatchPackSize<4, T>;
} else if (num_dims == 5) {
func = DispatchPackSize<5, T>;
} else if (num_dims == 6) {
func = DispatchPackSize<6, T>;
} else if (num_dims == 7) {
func = DispatchPackSize<7, T>;
} else if (num_dims == 8) {
func = DispatchPackSize<8, T>;
} else {
UNIMPLEMENTED();
}
func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val);
}
template<typename T>
void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, T pad_val,
void* dst) {
CHECK_LE(num_dims, kMaxNumDims);
int64_t simplified_dst_dims[kMaxNumDims];
int64_t simplified_src_dims[kMaxNumDims];
int64_t simplified_padding_before[kMaxNumDims];
int64_t simplified_padding_after[kMaxNumDims];
size_t simplified_num_dims = 1;
SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims,
simplified_dst_dims, simplified_src_dims, simplified_padding_before,
simplified_padding_after);
LaunchWithSimplified<T>(stream, simplified_num_dims, dst, simplified_dst_dims, src,
simplified_src_dims, simplified_padding_before, simplified_padding_after,
pad_val);
}
template<typename T>
class ConstantPadImpl : public ConstantPad {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl);
ConstantPadImpl() = default;
~ConstantPadImpl() override = default;
void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,
void* dst) override {
SimplifyThenLaunch<T>(stream, num_dims, src_dims, src, padding_before, padding_after,
GetValue<T>(pad_val), dst);
}
};
template<typename T>
std::unique_ptr<ConstantPad> NewConstantPad() {
return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>());
}
class ConstantPadFactoryImpl : public ConstantPadFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl);
ConstantPadFactoryImpl() = default;
~ConstantPadFactoryImpl() override = default;
std::unique_ptr<ConstantPad> New(DataType data_type) override {
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<ConstantPad>()>>
new_constant_pad_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
const auto it = new_constant_pad_handle.find(data_type);
if (it != new_constant_pad_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/constant_pad.h"
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t num_dims, typename IndexType, typename StorageType>
__global__ void ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params,
StorageType packed_pad_val) {
const StorageType* src = reinterpret_cast<const StorageType*>(params.src);
StorageType* dst = reinterpret_cast<StorageType*>(params.dst);
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) {
params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);
bool if_pad = false;
#pragma unroll
for (int i = 0; i < num_dims; i++) {
if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) {
src_index[i] = dst_index[i] - params.valid_start[i];
} else {
if_pad = true;
break;
}
}
StorageType dst_val = packed_pad_val;
if (!if_pad) {
const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
dst_val = src[src_offset];
}
dst[linear_index] = dst_val;
}
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, ConstantPadParams<num_dims, IndexType> params,
StorageType packed_pad_val, size_t elem_cnt) {
stream->As<CudaStream>()->LaunchKernelDefaultWaves(
(ConstantPadKernel<num_dims, IndexType, StorageType>), elem_cnt, params, packed_pad_val);
}
template<size_t num_dims, typename IndexType, typename StorageType>
void LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
ConstantPadParams<num_dims, IndexType> params;
params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims);
params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);
params.dst = dst;
params.src = src;
for (int i = 0; i < num_dims; i++) {
params.valid_start[i] = padding_before[i];
params.valid_end[i] = dst_dims[i] - padding_after[i];
}
params.elem_cnt = elem_cnt;
LaunchKernel<num_dims, IndexType, StorageType>(stream, params, packed_pad_val, elem_cnt);
}
template<size_t num_dims, typename StorageType>
void DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,
const int64_t* src_dims, const int64_t* padding_before,
const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {
if (elem_cnt < GetMaxVal<int32_t>()) {
LaunchKernel<num_dims, int32_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val,
elem_cnt);
} else {
LaunchKernel<num_dims, int64_t, StorageType>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after, packed_pad_val,
elem_cnt);
}
}
template<size_t num_dims, typename T>
void DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src,
int64_t* src_dims, int64_t* padding_before, int64_t* padding_after,
T pad_val) {
constexpr int32_t max_packsize = GetMaxPackSize<T>();
size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims,
padding_before, padding_after);
dst_dims[num_dims - 1] /= launch_pack_size;
src_dims[num_dims - 1] /= launch_pack_size;
padding_before[num_dims - 1] /= launch_pack_size;
padding_after[num_dims - 1] /= launch_pack_size;
size_t elem_cnt = 1;
for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; }
if (launch_pack_size == 1) {
Pack<T, 1> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 1>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 2) {
Pack<T, 2> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 2>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 4) {
Pack<T, 4> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 4>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 8) {
Pack<T, 8> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 8>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else if (launch_pack_size == 16) {
Pack<T, 16> packed_pad_val(pad_val);
DispatchIndexType<num_dims, PackType<T, 16>>(stream, dst, dst_dims, src, src_dims,
padding_before, padding_after,
packed_pad_val.storage, elem_cnt);
} else {
UNIMPLEMENTED();
}
}
template<typename T>
void LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims,
const void* src, int64_t* src_dims, int64_t* padding_before,
int64_t* padding_after, T pad_val) {
void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/,
int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) =
nullptr;
if (num_dims == 1) {
func = DispatchPackSize<1, T>;
} else if (num_dims == 2) {
func = DispatchPackSize<2, T>;
} else if (num_dims == 3) {
func = DispatchPackSize<3, T>;
} else if (num_dims == 4) {
func = DispatchPackSize<4, T>;
} else if (num_dims == 5) {
func = DispatchPackSize<5, T>;
} else if (num_dims == 6) {
func = DispatchPackSize<6, T>;
} else if (num_dims == 7) {
func = DispatchPackSize<7, T>;
} else if (num_dims == 8) {
func = DispatchPackSize<8, T>;
} else {
UNIMPLEMENTED();
}
func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val);
}
template<typename T>
void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, T pad_val,
void* dst) {
CHECK_LE(num_dims, kMaxNumDims);
int64_t simplified_dst_dims[kMaxNumDims];
int64_t simplified_src_dims[kMaxNumDims];
int64_t simplified_padding_before[kMaxNumDims];
int64_t simplified_padding_after[kMaxNumDims];
size_t simplified_num_dims = 1;
SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims,
simplified_dst_dims, simplified_src_dims, simplified_padding_before,
simplified_padding_after);
LaunchWithSimplified<T>(stream, simplified_num_dims, dst, simplified_dst_dims, src,
simplified_src_dims, simplified_padding_before, simplified_padding_after,
pad_val);
}
template<typename T>
class ConstantPadImpl : public ConstantPad {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl);
ConstantPadImpl() = default;
~ConstantPadImpl() override = default;
void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,
void* dst) override {
SimplifyThenLaunch<T>(stream, num_dims, src_dims, src, padding_before, padding_after,
GetValue<T>(pad_val), dst);
}
};
template<typename T>
std::unique_ptr<ConstantPad> NewConstantPad() {
return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>());
}
class ConstantPadFactoryImpl : public ConstantPadFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl);
ConstantPadFactoryImpl() = default;
~ConstantPadFactoryImpl() override = default;
std::unique_ptr<ConstantPad> New(DataType data_type) override {
#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<ConstantPad>()>>
new_constant_pad_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_CONSTANT_PAD_ENTRY
const auto it = new_constant_pad_handle.find(data_type);
if (it != new_constant_pad_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {
using T = typename std::aligned_storage<movement_size, movement_size>::type;
const T* src = reinterpret_cast<const T*>(params.src);
T* dst = reinterpret_cast<T*>(params.dst);
IndexType copy_index[num_dims];
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
params.copy_index_helper.OffsetToNdIndex(i, copy_index);
#pragma unroll
for (size_t j = 0; j < num_dims; ++j) {
src_index[j] = params.src_pos[j] + copy_index[j];
dst_index[j] = params.dst_pos[j] + copy_index[j];
}
const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index);
dst[dst_offset] = src[src_offset];
}
}
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
CopyNdKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
class CopyNdImpl : public CopyNd {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl);
CopyNdImpl() = default;
~CopyNdImpl() override = default;
void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,
const int64_t* dst_dims, const int64_t* dst_pos, const void* src,
const int64_t* src_dims, const int64_t* src_pos,
const int64_t* extent) const override {
SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos,
extent);
}
};
class CopyNdFactoryImpl : public CopyNdFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl);
CopyNdFactoryImpl() = default;
~CopyNdFactoryImpl() override = default;
std::unique_ptr<CopyNd> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<CopyNd>(new CopyNdImpl());
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {
using T = typename std::aligned_storage<movement_size, movement_size>::type;
const T* src = reinterpret_cast<const T*>(params.src);
T* dst = reinterpret_cast<T*>(params.dst);
IndexType copy_index[num_dims];
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
params.copy_index_helper.OffsetToNdIndex(i, copy_index);
#pragma unroll
for (size_t j = 0; j < num_dims; ++j) {
src_index[j] = params.src_pos[j] + copy_index[j];
dst_index[j] = params.dst_pos[j] + copy_index[j];
}
const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index);
dst[dst_offset] = src[src_offset];
}
}
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
CopyNdKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
class CopyNdImpl : public CopyNd {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl);
CopyNdImpl() = default;
~CopyNdImpl() override = default;
void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,
const int64_t* dst_dims, const int64_t* dst_pos, const void* src,
const int64_t* src_dims, const int64_t* src_pos,
const int64_t* extent) const override {
SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos,
extent);
}
};
class CopyNdFactoryImpl : public CopyNdFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl);
CopyNdFactoryImpl() = default;
~CopyNdFactoryImpl() override = default;
std::unique_ptr<CopyNd> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<CopyNd>(new CopyNdImpl());
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<UnaryOp unary_op, typename Src, typename Dst>
class ElementwiseUnaryImpl : public ElementwiseUnary {
public:
OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);
ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~ElementwiseUnaryImpl() override = default;
void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);
OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(
functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src),
cuda_stream->cuda_stream())));
}
protected:
Scalar attr0, attr1;
};
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {
return std::unique_ptr<ElementwiseUnary>(
new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}
class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);
ElementwiseUnaryFactoryImpl() = default;
~ElementwiseUnaryFactoryImpl() override = default;
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
DataType dst_dtype) override {
return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());
}
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
Scalar attr0) override {
return New(unary_op, src_type, dst_dtype, attr0, Scalar());
}
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
Scalar attr0, Scalar attr1) override {
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair), \
OF_PP_PAIR_FIRST(dst_dtype_pair)>},
static const std::map<std::tuple<UnaryOp, DataType, DataType>,
std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>
new_elementwise_unary_handle{
// For All Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
// For Float Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_FLOATING_MATH_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
// For Logical OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
const auto it =
new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));
if (it != new_elementwise_unary_handle.end()) {
return it->second(attr0, attr1);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/ep/rocm/primitive/unary_functor.hip.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<UnaryOp unary_op, typename Src, typename Dst>
class ElementwiseUnaryImpl : public ElementwiseUnary {
public:
OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);
ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
~ElementwiseUnaryImpl() override = default;
void Launch(Stream* stream, const void* src, void* dst, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);
OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(
functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src),
cuda_stream->cuda_stream())));
}
protected:
Scalar attr0, attr1;
};
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {
return std::unique_ptr<ElementwiseUnary>(
new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}
class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);
ElementwiseUnaryFactoryImpl() = default;
~ElementwiseUnaryFactoryImpl() override = default;
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
DataType dst_dtype) override {
return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());
}
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
Scalar attr0) override {
return New(unary_op, src_type, dst_dtype, attr0, Scalar());
}
std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,
Scalar attr0, Scalar attr1) override {
#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair) \
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},
#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair) \
{std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \
NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair), \
OF_PP_PAIR_FIRST(dst_dtype_pair)>},
static const std::map<std::tuple<UnaryOp, DataType, DataType>,
std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>
new_elementwise_unary_handle{
// For All Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
// For Float Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_FLOATING_MATH_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
// For Logical OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY
#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY
const auto it =
new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));
if (it != new_elementwise_unary_handle.end()) {
return it->second(attr0, attr1);
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
\ No newline at end of file
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t size>
using Storage = typename std::aligned_storage<size, size>::type;
template<typename T, size_t pack>
union Pack {
static constexpr size_t size = sizeof(T) * pack;
explicit __device__ __host__ Pack(T value) {
static_assert(sizeof(Pack) == size, "");
static_assert(alignof(Pack) == size, "");
#pragma unroll
for (size_t i = 0; i < pack; ++i) { elem[i] = value; }
}
T elem[pack];
Storage<size> storage;
};
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
const size_t pack_count = count / pack;
Pack<T, pack> pack_value(value);
auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);
CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }
T* tail_dst = dst + pack_count * pack;
const size_t tail_count = count - pack_count * pack;
CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}
template<typename T>
T GetValue(Scalar value) {
return value.Value<T>();
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<typename T, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
T value, size_t count) {
FillGpu<T, pack>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
}
template<typename T, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
T value, size_t count) {
LOG(FATAL) << "wrong alignment";
}
template<typename T>
void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) {
auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
if (uintptr % 16 == 0) {
LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 8 == 0) {
LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 4 == 0) {
LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 2 == 0) {
LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);
} else {
LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);
}
}
template<typename T>
class FillImpl : public Fill {
public:
OF_DISALLOW_COPY_AND_MOVE(FillImpl);
FillImpl() = default;
~FillImpl() override = default;
void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
}
};
template<typename T>
std::unique_ptr<Fill> NewFill() {
return std::unique_ptr<Fill>(new FillImpl<T>());
}
class FillFactoryImpl : public FillFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
FillFactoryImpl() = default;
~FillFactoryImpl() override = default;
std::unique_ptr<Fill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
const auto it = new_fill_handle.find(data_type);
if (it != new_fill_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/fill.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<size_t size>
using Storage = typename std::aligned_storage<size, size>::type;
template<typename T, size_t pack>
union Pack {
static constexpr size_t size = sizeof(T) * pack;
explicit __device__ __host__ Pack(T value) {
static_assert(sizeof(Pack) == size, "");
static_assert(alignof(Pack) == size, "");
#pragma unroll
for (size_t i = 0; i < pack; ++i) { elem[i] = value; }
}
T elem[pack];
Storage<size> storage;
};
template<typename T, size_t pack>
__global__ void FillGpu(T* dst, T value, size_t count) {
const size_t pack_count = count / pack;
Pack<T, pack> pack_value(value);
auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);
CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }
T* tail_dst = dst + pack_count * pack;
const size_t tail_count = count - pack_count * pack;
CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
}
template<typename T>
T GetValue(Scalar value) {
return value.Value<T>();
}
template<>
half GetValue<half>(Scalar value) {
return static_cast<half>(GetValue<float>(value));
}
// #if CUDA_VERSION >= 11000
// template<>
// nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
// return static_cast<nv_bfloat16>(GetValue<float>(value));
// }
// #endif // CUDA_VERSION >= 11000
template<typename T, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
T value, size_t count) {
FillGpu<T, pack>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
}
template<typename T, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(hipStream_t stream, T* dst,
T value, size_t count) {
LOG(FATAL) << "wrong alignment";
}
template<typename T>
void LaunchFill(hipStream_t stream, T* dst, T value, size_t count) {
auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
if (uintptr % 16 == 0) {
LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 8 == 0) {
LaunchPackFill<T, 8 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 4 == 0) {
LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);
} else if (uintptr % 2 == 0) {
LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);
} else {
LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);
}
}
template<typename T>
class FillImpl : public Fill {
public:
OF_DISALLOW_COPY_AND_MOVE(FillImpl);
FillImpl() = default;
~FillImpl() override = default;
void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
}
};
template<typename T>
std::unique_ptr<Fill> NewFill() {
return std::unique_ptr<Fill>(new FillImpl<T>());
}
class FillFactoryImpl : public FillFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);
FillFactoryImpl() = default;
~FillFactoryImpl() override = default;
std::unique_ptr<Fill> New(DataType data_type) override {
#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_FILL_ENTRY
const auto it = new_fill_handle.find(data_type);
if (it != new_fill_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
class MemcpyImpl : public Memcpy {
public:
OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);
MemcpyImpl() = default;
~MemcpyImpl() override = default;
void Launch(Stream* stream, void* dst, const void* src, size_t count) override {
if (dst == src) { return; }
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK(hipMemcpyAsync(dst, src, count, hipMemcpyDefault, cuda_stream->cuda_stream()));
}
};
class MemcpyFactoryImpl : public MemcpyFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);
MemcpyFactoryImpl() = default;
~MemcpyFactoryImpl() override = default;
std::unique_ptr<Memcpy> New(MemcpyKind kind) override {
return std::unique_ptr<Memcpy>(new MemcpyImpl());
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memcpy.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
class MemcpyImpl : public Memcpy {
public:
OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);
MemcpyImpl() = default;
~MemcpyImpl() override = default;
void Launch(Stream* stream, void* dst, const void* src, size_t count) override {
if (dst == src) { return; }
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK(hipMemcpyAsync(dst, src, count, hipMemcpyDefault, cuda_stream->cuda_stream()));
}
};
class MemcpyFactoryImpl : public MemcpyFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);
MemcpyFactoryImpl() = default;
~MemcpyFactoryImpl() override = default;
std::unique_ptr<Memcpy> New(MemcpyKind kind) override {
return std::unique_ptr<Memcpy>(new MemcpyImpl());
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
class MemsetImpl : public Memset {
public:
OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);
MemsetImpl() = default;
~MemsetImpl() override = default;
void Launch(Stream* stream, void* ptr, int value, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(ptr, value, count, cuda_stream->cuda_stream()));
}
};
class MemsetFactoryImpl : public MemsetFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);
MemsetFactoryImpl() = default;
~MemsetFactoryImpl() override = default;
std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); }
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_ROCM
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
class MemsetImpl : public Memset {
public:
OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);
MemsetImpl() = default;
~MemsetImpl() override = default;
void Launch(Stream* stream, void* ptr, int value, size_t count) override {
auto* cuda_stream = stream->As<CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(ptr, value, count, cuda_stream->cuda_stream()));
}
};
class MemsetFactoryImpl : public MemsetFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);
MemsetFactoryImpl() = default;
~MemsetFactoryImpl() override = default;
std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); }
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace permute {
namespace internal {
namespace {
constexpr int32_t kMov4TileSize = 32;
constexpr int32_t kMov2TileSize = 64;
constexpr int32_t kBlockRows = 8;
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
using T = typename std::aligned_storage<movement_size, movement_size>::type;
const T* src = reinterpret_cast<const T*>(params.src);
T* dst = reinterpret_cast<T*>(params.dst);
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
for (size_t dim = 0; dim < num_dims; ++dim) {
src_index[params.permutation[dim]] = dst_index[dim];
}
IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
dst[i] = src[src_offset];
}
}
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
IndexType cols, IndexType num_tile_rows,
IndexType num_tile_cols, int32_t block_nums) {
const IndexType src_rows = rows;
const IndexType src_cols = cols;
const IndexType dst_rows = cols;
const IndexType dst_cols = rows;
using T = typename std::aligned_storage<movement_size, movement_size>::type;
__shared__ T tile[tile_size][tile_size + 1]; // To avoid bank conflict.
const T* src = reinterpret_cast<const T*>(src_ptr);
T* dst = reinterpret_cast<T*>(dst_ptr);
IndexType batch_num_tile = num_tile_rows * num_tile_cols;
for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
const IndexType batch_index = i / batch_num_tile; // the index of batch.
const IndexType tile_index =
i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the
// flatten index of tile in a batch.
const IndexType tile_row_index =
tile_index / num_tile_cols; // the row index of tile in a batch.
const IndexType tile_col_index =
tile_index
- tile_row_index
* num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch.
const IndexType offset = batch_index * src_rows * src_cols;
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
}
}
}
__syncthreads();
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
}
}
}
__syncthreads();
}
}
/*
Here is a Movementsie=2 version of Batch Transpose.
When the H W can be divided by 2. we can read data use movementsize=4, and write back as
movementsize=4.
*/
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
IndexType cols, IndexType num_tile_rows,
IndexType num_tile_cols, int32_t block_nums) {
const IndexType src_rows = rows;
const IndexType src_cols = cols;
const IndexType dst_rows = cols;
const IndexType dst_cols = rows;
static_assert(tile_size % 2 == 0, "");
using T_MOV2 = typename std::aligned_storage<2, 2>::type;
using T_MOV4 = typename std::aligned_storage<4, 4>::type;
const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
// Use union structure to process Load and Store.
__shared__ union {
T_MOV2 tile_m2[tile_size][tile_size + 2]; // half [64][66]
T_MOV4 tile_m4[tile_size][tile_size / 2 + 1]; // half2 [64][33]
} tile_mem;
IndexType batch_num_tile = num_tile_rows * num_tile_cols;
for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
const IndexType batch_index = i / batch_num_tile; // the index of batch.
const IndexType tile_index =
i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the
// flatten index of tile in a batch.
const IndexType tile_row_index =
tile_index / num_tile_cols; // the row index of tile in a batch.
const IndexType tile_col_index =
tile_index
- tile_row_index
* num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch.
const IndexType offset = batch_index * src_rows * src_cols;
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
tile_mem.tile_m4[row_in_tile][col_in_tile] =
src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
}
}
}
__syncthreads();
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
union {
T_MOV4 m4;
T_MOV2 m2[2];
} tmp_storage;
if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
}
}
}
__syncthreads();
}
}
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
void LaunchBatchTransposeKernel(hipStream_t& cuda_stream,
const PermuteKernelParams<num_dims, IndexType>& params,
const IndexType& num_batches, const IndexType& rows,
const IndexType& cols) {
IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;
IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;
const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
if (tile_size == kMov2TileSize) {
const int32_t half2_thread = tile_size / 2; // cause each thread process two half elements.
BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>
<<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
block_nums); // Set threads num as 32x8 cause each threads
// process 4 elements to 64x66 half share memory.
} else {
BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>
<<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
}
}
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
if (rows < tile_size || cols < tile_size) { return false; }
return true;
}
template<size_t num_dims, size_t tile_size, typename IndexType>
bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
const IndexType& rows, const IndexType& cols) {
if (CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) {
if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) {
// 2d tensor case: (0, 1) -> (1, 0)
return true;
} else if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) {
// 3d tensor case: (0, 1, 2) -> (0, 2, 1)
return true;
} else {
return false;
}
}
return false;
}
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
auto src_ptr = reinterpret_cast<std::uintptr_t>(src);
auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
return (movement_size == 2) && (rows % 2 == 0) && (cols % 2 == 0) && (src_ptr % 4 == 0)
&& (dst_ptr % 4 == 0);
}
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
IndexType* cols) {
if (num_dims == 2) {
*num_batches = 1;
*rows = src_dims[0];
*cols = src_dims[1];
} else {
*num_batches = src_dims[0];
*rows = src_dims[1];
*cols = src_dims[2];
}
}
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,
void* dst, size_t count) {
PermuteKernelParams<num_dims, IndexType> params =
MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
if (num_dims == 2 || num_dims == 3) {
IndexType num_batches;
IndexType rows;
IndexType cols;
InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,
cols)) {
if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {
LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
num_batches, rows, cols);
} else {
LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
cuda_stream, params, num_batches, rows, cols);
}
} else {
PermuteKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
} else {
PermuteKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
}
class PermuteImpl : public Permute {
public:
OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
PermuteImpl() = default;
~PermuteImpl() override = default;
using Permute::Launch;
void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
const void* src, const int* permutation, void* dst) override {
SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);
}
};
class PermuteFactoryImpl : public PermuteFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);
PermuteFactoryImpl() = default;
~PermuteFactoryImpl() override = default;
std::unique_ptr<Permute> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<Permute>(new PermuteImpl());
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl);
} // namespace
} // namespace internal
} // namespace permute
} // namespace primitive
} // namespace ep
} // namespace oneflow
#include "hip/hip_runtime.h"
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/permute.h"
#include "oneflow/core/ep/common/primitive/permute_impl.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
namespace primitive {
namespace permute {
namespace internal {
namespace {
constexpr int32_t kMov4TileSize = 32;
constexpr int32_t kMov2TileSize = 64;
constexpr int32_t kBlockRows = 8;
template<size_t num_dims, size_t movement_size, typename IndexType>
__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
using T = typename std::aligned_storage<movement_size, movement_size>::type;
const T* src = reinterpret_cast<const T*>(params.src);
T* dst = reinterpret_cast<T*>(params.dst);
IndexType src_index[num_dims];
IndexType dst_index[num_dims];
CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
params.dst_index_helper.OffsetToNdIndex(i, dst_index);
#pragma unroll
for (size_t dim = 0; dim < num_dims; ++dim) {
src_index[params.permutation[dim]] = dst_index[dim];
}
IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
dst[i] = src[src_offset];
}
}
// (B, X, Y) -> (B, Y, X)
// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,
IndexType cols, IndexType num_tile_rows,
IndexType num_tile_cols, int32_t block_nums) {
const IndexType src_rows = rows;
const IndexType src_cols = cols;
const IndexType dst_rows = cols;
const IndexType dst_cols = rows;
using T = typename std::aligned_storage<movement_size, movement_size>::type;
__shared__ T tile[tile_size][tile_size + 1]; // To avoid bank conflict.
const T* src = reinterpret_cast<const T*>(src_ptr);
T* dst = reinterpret_cast<T*>(dst_ptr);
IndexType batch_num_tile = num_tile_rows * num_tile_cols;
for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
const IndexType batch_index = i / batch_num_tile; // the index of batch.
const IndexType tile_index =
i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the
// flatten index of tile in a batch.
const IndexType tile_row_index =
tile_index / num_tile_cols; // the row index of tile in a batch.
const IndexType tile_col_index =
tile_index
- tile_row_index
* num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch.
const IndexType offset = batch_index * src_rows * src_cols;
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];
}
}
}
__syncthreads();
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];
}
}
}
__syncthreads();
}
}
/*
Here is a Movementsie=2 version of Batch Transpose.
When the H W can be divided by 2. we can read data use movementsize=4, and write back as
movementsize=4.
*/
template<size_t num_dims, size_t tile_size, typename IndexType>
__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
IndexType cols, IndexType num_tile_rows,
IndexType num_tile_cols, int32_t block_nums) {
const IndexType src_rows = rows;
const IndexType src_cols = cols;
const IndexType dst_rows = cols;
const IndexType dst_cols = rows;
static_assert(tile_size % 2 == 0, "");
using T_MOV2 = typename std::aligned_storage<2, 2>::type;
using T_MOV4 = typename std::aligned_storage<4, 4>::type;
const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
// Use union structure to process Load and Store.
__shared__ union {
T_MOV2 tile_m2[tile_size][tile_size + 2]; // half [64][66]
T_MOV4 tile_m4[tile_size][tile_size / 2 + 1]; // half2 [64][33]
} tile_mem;
IndexType batch_num_tile = num_tile_rows * num_tile_cols;
for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
const IndexType batch_index = i / batch_num_tile; // the index of batch.
const IndexType tile_index =
i - batch_index * batch_num_tile; // equal to i % (num_tile_rows*num_tile_cols). the
// flatten index of tile in a batch.
const IndexType tile_row_index =
tile_index / num_tile_cols; // the row index of tile in a batch.
const IndexType tile_col_index =
tile_index
- tile_row_index
* num_tile_cols; // equal to k % num_tile_cols. the col index of tile in a batch.
const IndexType offset = batch_index * src_rows * src_cols;
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;
if (col_in_matrix < src_cols && row_in_matrix < src_rows) {
tile_mem.tile_m4[row_in_tile][col_in_tile] =
src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];
}
}
}
__syncthreads();
{
IndexType col_in_tile = threadIdx.x;
IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;
#pragma unroll
for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;
row_in_tile += kBlockRows) {
IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;
union {
T_MOV4 m4;
T_MOV2 m2[2];
} tmp_storage;
if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {
tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];
tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];
dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;
}
}
}
__syncthreads();
}
}
template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
void LaunchBatchTransposeKernel(hipStream_t& cuda_stream,
const PermuteKernelParams<num_dims, IndexType>& params,
const IndexType& num_batches, const IndexType& rows,
const IndexType& cols) {
IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;
IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;
const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
if (tile_size == kMov2TileSize) {
const int32_t half2_thread = tile_size / 2; // cause each thread process two half elements.
BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>
<<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
block_nums); // Set threads num as 32x8 cause each threads
// process 4 elements to 64x66 half share memory.
} else {
BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>
<<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
}
}
template<size_t tile_size, typename IndexType>
bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
if (rows < tile_size || cols < tile_size) { return false; }
return true;
}
template<size_t num_dims, size_t tile_size, typename IndexType>
bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
const IndexType& rows, const IndexType& cols) {
if (CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) {
if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) {
// 2d tensor case: (0, 1) -> (1, 0)
return true;
} else if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) {
// 3d tensor case: (0, 1, 2) -> (0, 2, 1)
return true;
} else {
return false;
}
}
return false;
}
template<typename IndexType, size_t movement_size>
bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
auto src_ptr = reinterpret_cast<std::uintptr_t>(src);
auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
return (movement_size == 2) && (rows % 2 == 0) && (cols % 2 == 0) && (src_ptr % 4 == 0)
&& (dst_ptr % 4 == 0);
}
template<size_t num_dims, typename IndexType>
void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
IndexType* cols) {
if (num_dims == 2) {
*num_batches = 1;
*rows = src_dims[0];
*cols = src_dims[1];
} else {
*num_batches = src_dims[0];
*rows = src_dims[1];
*cols = src_dims[2];
}
}
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,
void* dst, size_t count) {
PermuteKernelParams<num_dims, IndexType> params =
MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
if (num_dims == 2 || num_dims == 3) {
IndexType num_batches;
IndexType rows;
IndexType cols;
InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,
cols)) {
if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {
LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
num_batches, rows, cols);
} else {
LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
cuda_stream, params, num_batches, rows, cols);
}
} else {
PermuteKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
} else {
PermuteKernel<num_dims, movement_size, IndexType>
<<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
}
class PermuteImpl : public Permute {
public:
OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
PermuteImpl() = default;
~PermuteImpl() override = default;
using Permute::Launch;
void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,
const void* src, const int* permutation, void* dst) override {
SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);
}
};
class PermuteFactoryImpl : public PermuteFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);
PermuteFactoryImpl() = default;
~PermuteFactoryImpl() override = default;
std::unique_ptr<Permute> New(size_t max_num_dims) override {
if (max_num_dims <= kMaxNumDims) {
return std::unique_ptr<Permute>(new PermuteImpl());
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl);
} // namespace
} // namespace internal
} // namespace permute
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
enum class Algorithm {
kSoftmax,
kLogSoftmax,
};
template<Algorithm algorithm, typename T>
void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols)));
} else {
UNIMPLEMENTED();
}
}
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
SoftmaxImpl() = default;
~SoftmaxImpl() override = default;
void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
reinterpret_cast<T*>(y));
}
};
template<typename SoftmaxBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBase> NewSoftmax() {
return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>());
}
template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>
class GenericSoftmaxFactoryImpl : public FactoryBase {
public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);
GenericSoftmaxFactoryImpl() = default;
~GenericSoftmaxFactoryImpl() override = default;
std::unique_ptr<SoftmaxBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>
new_softmax_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_handle.find(data_type);
if (it != new_softmax_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, Algorithm::kSoftmax>;
using LogSoftmaxFactoryImpl =
GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax.h"
#include "oneflow/core/ep/include/primitive/log_softmax.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
enum class Algorithm {
kSoftmax,
kLogSoftmax,
};
template<Algorithm algorithm, typename T>
void SoftmaxGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);
oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);
if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(
cuda_stream, load, store, rows, cols)));
} else {
UNIMPLEMENTED();
}
}
template<typename SoftmaxBase, Algorithm algorithm, typename T>
class SoftmaxImpl : public SoftmaxBase {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);
SoftmaxImpl() = default;
~SoftmaxImpl() override = default;
void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),
reinterpret_cast<T*>(y));
}
};
template<typename SoftmaxBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBase> NewSoftmax() {
return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>());
}
template<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>
class GenericSoftmaxFactoryImpl : public FactoryBase {
public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);
GenericSoftmaxFactoryImpl() = default;
~GenericSoftmaxFactoryImpl() override = default;
std::unique_ptr<SoftmaxBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>
new_softmax_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_handle.find(data_type);
if (it != new_softmax_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
using SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, Algorithm::kSoftmax>;
using LogSoftmaxFactoryImpl =
GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
enum class Algorithm {
kSoftmax,
kLogSoftmax,
};
template<Algorithm algorithm, typename T>
void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,
T* dx) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);
if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols)));
} else {
UNIMPLEMENTED();
}
}
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
class SoftmaxBackwardImpl : public SoftmaxBackwardBase {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);
SoftmaxBackwardImpl() = default;
~SoftmaxBackwardImpl() override = default;
void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,
void* dx) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
SoftmaxBackwardGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(y),
reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));
}
};
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() {
return std::unique_ptr<SoftmaxBackwardBase>(
new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>());
}
template<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm>
class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {
public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);
GenericSoftmaxBackwardFactoryImpl() = default;
~GenericSoftmaxBackwardFactoryImpl() override = default;
std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>
new_softmax_backward_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_backward_handle.find(data_type);
if (it != new_softmax_backward_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
using SoftmaxBackwardFactoryImpl =
GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward, Algorithm::kSoftmax>;
using LogSoftmaxBackwardFactoryImpl =
GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,
Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory,
LogSoftmaxBackwardFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/softmax_backward.h"
#include "oneflow/core/ep/include/primitive/log_softmax_backward.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/softmax.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
enum class Algorithm {
kSoftmax,
kLogSoftmax,
};
template<Algorithm algorithm, typename T>
void SoftmaxBackwardGpu(hipStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,
T* dx) {
using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);
cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);
cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);
if (algorithm == Algorithm::kSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols)));
} else if (algorithm == Algorithm::kLogSoftmax) {
OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),
decltype(store), ComputeType>(
cuda_stream, load_y, load_dy, store, rows, cols)));
} else {
UNIMPLEMENTED();
}
}
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
class SoftmaxBackwardImpl : public SoftmaxBackwardBase {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);
SoftmaxBackwardImpl() = default;
~SoftmaxBackwardImpl() override = default;
void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,
void* dx) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
SoftmaxBackwardGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(y),
reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));
}
};
template<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>
std::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() {
return std::unique_ptr<SoftmaxBackwardBase>(
new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>());
}
template<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm>
class GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {
public:
OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);
GenericSoftmaxBackwardFactoryImpl() = default;
~GenericSoftmaxBackwardFactoryImpl() override = default;
std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {
#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \
{type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>
new_softmax_backward_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_SOFTMAX_ENTRY
const auto it = new_softmax_backward_handle.find(data_type);
if (it != new_softmax_backward_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
using SoftmaxBackwardFactoryImpl =
GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward, Algorithm::kSoftmax>;
using LogSoftmaxBackwardFactoryImpl =
GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,
Algorithm::kLogSoftmax>;
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory,
LogSoftmaxBackwardFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
// #if CUDA_VERSION >= 11000
// #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define UTIL_OPS_DATA_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)
#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)
#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)
#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
// #if CUDA_VERSION >= 11000
// #define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #else
#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
// #endif // CUDA_VERSION >= 11000
#define CUDA_PRIMITIVE_ALL_TYPE_SEQ \
CUDA_PRIMITIVE_BOOL_TYPE_SEQ \
CUDA_PRIMITIVE_CHAR_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#define UTIL_OPS_DATA_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Src>(0.5) * src
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src));
}
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, float, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, double, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }
};
#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op) \
template<> \
struct UnaryFunctor<DeviceType::kCUDA, op, half, half> { \
UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src) const { \
return __float2half(float_functor(__half2float(src))); \
} \
};
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \
// template<> \
// struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// };
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// };
// #endif
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
namespace oneflow {
namespace ep {
namespace primitive {
template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Src>(0.5) * src
* (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src));
}
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, float, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, double, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }
};
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }
};
#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op) \
template<> \
struct UnaryFunctor<DeviceType::kCUDA, op, half, half> { \
UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
\
UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
OF_DEVICE_FUNC half operator()(half src) const { \
return __float2half(float_functor(__half2float(src))); \
} \
};
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// #define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op) \
// template<> \
// struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
// UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
// \
// UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor; \
// OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { \
// return __float2bfloat16(float_functor(__bfloat162float(src))); \
// } \
// };
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
// SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
// };
// template<>
// struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
// UnaryFunctor(Scalar attr0, Scalar attr1) {}
// OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
// };
// #endif
} // namespace primitive
} // namespace ep
} // namespace oneflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment