Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/foreign_lock_helper.h"
#include "oneflow/core/common/singleton.h"
namespace oneflow {
class NoForeignLockHelper final : public ForeignLockHelper {
Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>& Callback) const override {
return Callback();
}
Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>& Callback) const override {
return Callback();
}
};
static int __register_no_foreign_lock_helper __attribute__((unused)) = []() {
Singleton<ForeignLockHelper>::SetAllocated(new NoForeignLockHelper());
return 0;
}();
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H
#define ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H
#include <functional>
#include "oneflow/core/common/maybe.h"
namespace oneflow {
class ForeignLockHelper {
public:
virtual ~ForeignLockHelper() = default;
virtual Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>&) const = 0;
virtual Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>&) const = 0;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_
#define ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_
#include <tuple>
namespace oneflow {
template<typename... Args>
using void_t = void;
template<typename T, typename = void>
struct function_traits;
template<typename Ret, typename... Args>
struct function_traits<Ret(Args...)> {
using func_type = Ret(Args...);
using return_type = Ret;
using args_type = std::tuple<Args...>;
template<size_t i>
using arg_type = typename std::tuple_element<i, args_type>::type;
static constexpr size_t nargs = sizeof...(Args);
};
template<typename Ret, typename... Args>
struct function_traits<Ret (*)(Args...)> {
using func_type = Ret(Args...);
using return_type = Ret;
using args_type = std::tuple<Args...>;
template<size_t i>
using arg_type = typename std::tuple_element<i, args_type>::type;
static constexpr size_t nargs = sizeof...(Args);
};
template<typename Ret, typename C, typename... Args>
struct function_traits<Ret (C::*)(Args...)> {
using func_type = Ret(Args...);
using return_type = Ret;
using args_type = std::tuple<Args...>;
template<size_t i>
using arg_type = typename std::tuple_element<i, args_type>::type;
static constexpr size_t nargs = sizeof...(Args);
};
template<typename Ret, typename C, typename... Args>
struct function_traits<Ret (C::*)(Args...) const> {
using func_type = Ret(Args...);
using return_type = Ret;
using args_type = std::tuple<Args...>;
template<size_t i>
using arg_type = typename std::tuple_element<i, args_type>::type;
static constexpr size_t nargs = sizeof...(Args);
};
template<typename F>
struct function_traits<F, void_t<decltype(&F::operator())>>
: public function_traits<decltype(&F::operator())> {};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_HASH_CONTAINER_
#define ONEFLOW_CORE_COMMON_HASH_CONTAINER_
#include <unordered_set>
#include <unordered_map>
namespace oneflow {
template<typename Key, typename T, typename Hash = std::hash<Key>>
using HashMap = std::unordered_map<Key, T, Hash>;
template<typename Key, typename Hash = std::hash<Key>>
using HashSet = std::unordered_set<Key, Hash>;
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_HASH_CONTAINER_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_
#define ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_
namespace oneflow {
template<typename T>
class HashEqTraitPtr final {
public:
HashEqTraitPtr(const HashEqTraitPtr<T>&) = default;
HashEqTraitPtr(T* ptr, size_t hash_value) : ptr_(ptr), hash_value_(hash_value) {}
~HashEqTraitPtr() = default;
T* ptr() const { return ptr_; }
size_t hash_value() const { return hash_value_; }
bool operator==(const HashEqTraitPtr<T>& rhs) const { return *ptr_ == *rhs.ptr_; }
private:
T* ptr_;
size_t hash_value_;
};
} // namespace oneflow
namespace std {
template<typename T>
struct hash<oneflow::HashEqTraitPtr<T>> final {
size_t operator()(const oneflow::HashEqTraitPtr<T>& ptr) const { return ptr.hash_value(); }
};
} // namespace std
#endif // ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_
#define ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_
#include <string>
#include <memory>
#include <sstream>
#include <functional>
#include <utility>
#include "oneflow/core/common/function_traits.h"
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace hob {
template<typename Context, typename ValueT>
struct BaseExpr {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
// NOTE: Performance will be degraded if the destructor is virtual.
// So please do NOT implement custom destructor in any child classes of BaseExpr,
// and every fields of child classes should be of POD type.
~BaseExpr() = default;
#pragma GCC diagnostic pop
ALWAYS_INLINE virtual scalar_or_const_ref_t<ValueT> get(const Context&) const = 0;
virtual std::string DebugStr(const Context&, bool display_result = true) const = 0; // NOLINT
operator bool() = delete;
};
template<typename Context, typename ValueT, typename E>
struct Expr : public BaseExpr<Context, ValueT> {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
~Expr() = default;
#pragma GCC diagnostic pop
};
template<typename Context, typename ValueT>
struct Literal final : public Expr<Context, ValueT, Literal<Context, ValueT>> {
Literal(const ValueT& val) : Literal(ToString(val), val) {} // NOLINT
Literal(const std::string& debug_str, const ValueT& val) : val_(val), debug_str_(debug_str) {}
ALWAYS_INLINE scalar_or_const_ref_t<ValueT> get(const Context&) const override { return val_; }
std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; }
private:
ValueT val_;
std::string debug_str_;
};
template<typename Context>
using LiteralBool = Literal<Context, bool>;
template<typename Fn,
typename Context =
std::decay_t<typename oneflow::function_traits<Fn>::template arg_type<0>>,
typename ValueT = std::decay_t<typename oneflow::function_traits<Fn>::return_type>>
struct Custom final : public Expr<Context, ValueT, Custom<Fn>> {
explicit Custom(Fn fn) : Custom("", fn) {}
Custom(std::string debug_str, Fn fn) : fn_(std::move(fn)), debug_str_(std::move(debug_str)) {}
ALWAYS_INLINE scalar_or_const_ref_t<ValueT> get(const Context& context) const override {
return fn_(context);
}
std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; }
private:
Fn fn_;
std::string debug_str_;
};
template<typename Fn>
ALWAYS_INLINE inline Custom<Fn> make_custom(Fn fn) {
return Custom<Fn>(std::forward<Fn>(fn));
}
template<typename Fn>
ALWAYS_INLINE inline Custom<Fn> make_custom(const std::string& debug_str, Fn fn) {
return Custom<Fn>(debug_str, std::forward<Fn>(fn));
}
template<typename Context, typename E>
using BoolExpr = Expr<Context, bool, E>;
template<typename Context, typename E>
struct NotBoolFunctor final : public BoolExpr<Context, NotBoolFunctor<Context, E>> {
explicit NotBoolFunctor(const E& expr) : expr_(expr) {}
ALWAYS_INLINE bool get(const Context& context) const override { return !expr_.get(context); }
std::string DebugStr(const Context& ctx, bool display_result) const override {
std::ostringstream string_stream;
string_stream << "("
<< "not " << expr_.DebugStr(ctx, display_result) << ")";
return string_stream.str();
}
private:
const E expr_;
};
template<typename Context, typename E>
NotBoolFunctor<Context, E> operator!(BoolExpr<Context, E> const& lhs) {
return NotBoolFunctor<Context, E>(*static_cast<const E*>(&lhs));
}
#define DEFINE_BINARY_FUNCTOR(name, op) \
template<typename Context, typename E1, typename E2> \
struct name##BoolFunctor final : public BoolExpr<Context, name##BoolFunctor<Context, E1, E2>> { \
name##BoolFunctor(const E1& lhs, const E2& rhs) : lhs_(lhs), rhs_(rhs) {} \
\
ALWAYS_INLINE bool get(const Context& context) const override; \
\
std::string DebugStr(const Context& ctx, bool display_result) const override; \
\
private: \
const E1 lhs_; \
const E2 rhs_; \
}; \
\
template<typename Context, typename ValueT, typename E1, typename E2> \
name##BoolFunctor<Context, E1, E2> operator op(Expr<Context, ValueT, E1> const& lhs, \
Expr<Context, ValueT, E2> const& rhs) { \
return name##BoolFunctor<Context, E1, E2>(*static_cast<const E1*>(&lhs), \
*static_cast<const E2*>(&rhs)); \
} \
\
template<typename Context, typename ValueT, typename E1> \
name##BoolFunctor<Context, E1, Literal<Context, ValueT>> operator op( \
Expr<Context, ValueT, E1> const& lhs, ValueT const& rhs) { \
return name##BoolFunctor<Context, E1, Literal<Context, ValueT>>( \
*static_cast<const E1*>(&lhs), Literal<Context, ValueT>(rhs)); \
}
DEFINE_BINARY_FUNCTOR(Equal, ==)
DEFINE_BINARY_FUNCTOR(And, &&)
DEFINE_BINARY_FUNCTOR(Or, ||)
DEFINE_BINARY_FUNCTOR(Greater, >)
DEFINE_BINARY_FUNCTOR(Less, <)
DEFINE_BINARY_FUNCTOR(EqualOrGreater, >=)
DEFINE_BINARY_FUNCTOR(EqualOrLess, <=)
#undef DEFINE_BINARY_FUNCTOR
#define DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(name, op) \
template<typename Context, typename E1, typename E2> \
ALWAYS_INLINE inline bool name##BoolFunctor<Context, E1, E2>::get(const Context& context) \
const { \
return lhs_.get(context) op rhs_.get(context); \
} \
template<typename Context, typename E1, typename E2> \
std::string name##BoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx, \
bool display_result) const { \
std::string l_str = lhs_.DebugStr(ctx, display_result); \
std::string r_str = rhs_.DebugStr(ctx, display_result); \
std::ostringstream string_stream; \
string_stream << "(" << l_str << " " << OF_PP_STRINGIZE(op) << " " << r_str << ")"; \
return string_stream.str(); \
}
DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Equal, ==)
DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Greater, >)
DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Less, <)
DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrGreater, >=)
DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrLess, <=)
#undef DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS
template<typename Context, typename E1, typename E2>
ALWAYS_INLINE inline bool AndBoolFunctor<Context, E1, E2>::get(const Context& context) const {
bool lhs_result = lhs_.get(context);
if (!lhs_result) { return false; }
return rhs_.get(context);
}
template<typename Context, typename E1, typename E2>
std::string AndBoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx,
bool display_result) const {
std::string l_str = lhs_.DebugStr(ctx, display_result);
display_result = display_result && lhs_.get(ctx);
std::string r_str = rhs_.DebugStr(ctx, display_result);
std::ostringstream string_stream;
string_stream << "(" << l_str << " and " << r_str << ")";
return string_stream.str();
}
template<typename Context, typename E1, typename E2>
ALWAYS_INLINE inline bool OrBoolFunctor<Context, E1, E2>::get(const Context& context) const {
bool lhs_result = lhs_.get(context);
if (lhs_result) { return true; }
return rhs_.get(context);
}
template<typename Context, typename E1, typename E2>
std::string OrBoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx,
bool display_result) const {
std::string l_str = lhs_.DebugStr(ctx, display_result);
display_result = display_result && (!lhs_.get(ctx));
std::string r_str = rhs_.DebugStr(ctx, display_result);
std::ostringstream string_stream;
string_stream << "(" << l_str << " or " << r_str << ")";
return string_stream.str();
}
template<typename Context, typename E1>
EqualBoolFunctor<Context, E1, Literal<Context, std::string>> operator==(
Expr<Context, std::string, E1> const& lhs, const char* rhs) {
return EqualBoolFunctor<Context, E1, Literal<Context, std::string>>(
*static_cast<const E1*>(&lhs), Literal<Context, std::string>(rhs));
}
} // namespace hob
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_JUST_H_
#define ONEFLOW_CORE_COMMON_JUST_H_
#include <glog/logging.h>
#include <type_traits>
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/preprocessor.h"
namespace oneflow {
template<typename T, typename Enabled = void>
class Maybe;
template<typename T>
class Optional;
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>&);
namespace {
std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>&);
}
namespace private_details {
inline std::shared_ptr<ErrorProto>&& JustErrorAddStackFrame(std::shared_ptr<ErrorProto>&& err,
const std::string& file, int64_t line,
const std::string& func,
const std::string& message) {
auto* stack_frame = err->add_stack_frame();
stack_frame->set_file(file);
stack_frame->set_line(line);
stack_frame->set_function(func);
stack_frame->set_error_msg(message);
return std::move(err);
}
template<typename... T>
Error&& JustErrorAddMessage(Error&& err, T&&... msg) {
__attribute__((unused)) int dummy[] = {((void)(std::move(err) << std::forward<T>(msg)), 0)...};
return std::move(err);
}
template<typename T>
bool JustIsOk(const Maybe<T>& val) {
return val.IsOk();
}
template<typename T>
bool JustIsOk(const Optional<T>& val) {
return val.has_value();
}
template<typename T>
std::shared_ptr<ErrorProto> JustGetError(const Maybe<T>& val) {
return val.error();
}
template<typename T>
std::shared_ptr<ErrorProto> JustGetError(const Optional<T>&) {
return Error::ValueNotFoundError().error_proto();
}
template<typename T>
typename std::remove_const<typename std::remove_reference<T>::type>::type&& RemoveRValConst(
T&& v) noexcept {
static_assert(std::is_rvalue_reference<T&&>::value, "rvalue is expected here");
return const_cast<typename std::remove_const<typename std::remove_reference<T>::type>::type&&>(v);
}
} // namespace private_details
} // namespace oneflow
#define __JustStackCheckWrapper__(...) __VA_ARGS__
#define TRY(...) __JustStackCheckWrapper__(__VA_ARGS__)
#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)
#define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
_just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST_MSG(value, ...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!_just_value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#else
#error statement expression is no supported, please implement try-catch version of JUST
#endif
#endif // ONEFLOW_CORE_COMMON_JUST_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_
#define ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_
namespace oneflow {
template<typename T>
class LayoutStandardize final {
public:
void __Init__(const T& val) { new (&data_[0]) T(val); }
void __Delete__() { Mutable()->~T(); }
const T& Get() const { return *reinterpret_cast<const T*>(&data_[0]); }
T* Mutable() { return reinterpret_cast<T*>(&data_[0]); }
private:
union {
char data_[sizeof(T)];
int64_t align_;
};
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <utility>
#include "glog/logging.h"
#include "oneflow/core/common/math_util.h"
namespace oneflow {
int64_t Gcd(int64_t m, int64_t n) {
if (m < n) { std::swap(m, n); }
if (n == 0) { return m; }
CHECK_GT(m, 0);
CHECK_GT(n, 0);
return Gcd(n, m % n);
}
int64_t Lcm(int64_t m, int64_t n) { return m * n / Gcd(m, n); }
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_MATH_UTIL_H_
#define ONEFLOW_CORE_COMMON_MATH_UTIL_H_
#include <stdint.h>
namespace oneflow {
int64_t Gcd(int64_t m, int64_t n);
int64_t Lcm(int64_t m, int64_t n);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_MATH_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_MAYBE_H_
#define ONEFLOW_CORE_COMMON_MAYBE_H_
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/either_ptr.h"
#include "oneflow/core/common/shared_or_scalar.h"
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/just.h"
namespace oneflow {
template<typename T>
struct is_maybe {
static const bool value = false;
};
template<typename T>
struct is_maybe<Maybe<T>> {
static const bool value = true;
};
template<typename T>
class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScalarType<T>::value)
&& !std::is_reference<T>::value>::type>
final {
public:
Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {}
Maybe(T&& data) : data_or_error_(std::make_shared<T>(std::move(data))) {}
Maybe(const Error& error) : data_or_error_(error.error_proto()) {}
Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {}
Maybe(std::shared_ptr<T>&& data) : data_or_error_(std::move(data)) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : data_or_error_(error) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {}
~Maybe() = default;
bool IsOk() const { return data_or_error_.template Has<T>(); }
std::shared_ptr<T> Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return data_or_error_.template Get<T>();
}
std::shared_ptr<ErrorProto> error() const { return data_or_error_.template Get<ErrorProto>(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
return GetFormatedSerializedError(this->error());
}
template<typename Type = T>
Type GetDataAndSerializedErrorProto(std::string* error_str, const Type& default_for_error) const {
static_assert(std::is_same<T, Type>::value, "error type for argument 1");
if (IsOk()) {
*error_str = ErrorProto().DebugString();
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else {
*error_str = this->error()->DebugString();
return default_for_error;
}
}
template<typename Type = T>
std::pair<Type, std::shared_ptr<ErrorProto>> GetDataAndErrorProto(
const Type& default_for_error) const {
if (IsOk()) {
return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>());
} else {
return std::make_pair(default_for_error, error());
}
}
std::pair<std::shared_ptr<T>, std::shared_ptr<ErrorProto>> GetDataPtrAndErrorProto() const {
if (IsOk()) {
return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>());
} else {
return std::make_pair(std::shared_ptr<T>(), error());
}
}
template<typename Type = T>
Type GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); }
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
std::shared_ptr<T> GetPtrOrThrow() const {
if (!IsOk()) { ThrowError(error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
private:
EitherPtr<T, ErrorProto> data_or_error_;
};
template<typename T>
class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> final {
public:
Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_scalar_(error) { CheckError(); }
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
static Maybe Ok() { return Maybe(); }
bool IsOk() const { return error_or_scalar_.IsScalar(); }
void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {}
std::shared_ptr<ErrorProto> error() const { return error_or_scalar_.shared_ptr(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
return GetFormatedSerializedError(this->error());
}
void GetDataAndSerializedErrorProto(std::string* error_str) const {
if (IsOk()) {
*error_str = ErrorProto().DebugString();
} else {
*error_str = this->error()->DebugString();
}
}
std::shared_ptr<ErrorProto> GetDataAndErrorProto() const {
if (IsOk()) {
return std::shared_ptr<ErrorProto>();
} else {
return error();
}
}
void GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
private:
Maybe() : error_or_scalar_(nullptr) {}
void CheckError() const {
CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
}
SharedOrScalar<ErrorProto, void*> error_or_scalar_;
};
inline const std::shared_ptr<ErrorProto>& UninitializedValueError() {
static thread_local const auto& error =
Error::InvalidValueError("uninitialized value").error_proto();
return error;
}
template<typename T>
class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {
public:
Maybe(T data) : error_or_scalar_(data) {}
Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_scalar_(error) { CheckError(); }
Maybe() : error_or_scalar_(UninitializedValueError()) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
void operator=(const Maybe& rhs) { error_or_scalar_ = rhs.error_or_scalar_; }
bool IsOk() const { return error_or_scalar_.IsScalar(); }
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return error_or_scalar_.scalar_value();
}
std::shared_ptr<ErrorProto> error() const { return error_or_scalar_.shared_ptr(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
return GetFormatedSerializedError(this->error());
}
T GetDataAndSerializedErrorProto(std::string* error_str, const T& default_for_error) const {
if (IsOk()) {
*error_str = ErrorProto().DebugString();
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else {
*error_str = this->error()->DebugString();
return default_for_error;
}
}
std::pair<T, std::shared_ptr<ErrorProto>> GetDataAndErrorProto(const T& default_for_error) const {
if (IsOk()) {
return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>());
} else {
return std::make_pair(default_for_error, error());
}
}
T GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
private:
void CheckError() const {
CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
}
SharedOrScalar<ErrorProto, T> error_or_scalar_;
};
template<typename T>
class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScalarType<T>::value)
&& std::is_reference<T>::value>::type>
final {
using ValueT = typename std::remove_reference<T>::type;
using PtrT = ValueT*;
public:
Maybe(T data) : maybe_ptr_(&data) {}
Maybe(const Error& error) : maybe_ptr_(error) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : maybe_ptr_(error) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
bool IsOk() const { return maybe_ptr_.IsOk(); }
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
std::shared_ptr<ErrorProto> error() const { return maybe_ptr_.error(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
return maybe_ptr_.GetSerializedError();
}
T GetDataAndSerializedErrorProto(std::string* error_str) const {
return *maybe_ptr_.GetDataAndSerializedErrorProto(error_str, static_cast<PtrT>(nullptr));
}
T GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
private:
Maybe<PtrT> maybe_ptr_;
};
namespace {
std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_proto) {
// return error msg got from formatted function or debugstring.
const auto& maybe_error = TRY(FormatErrorStr(error_proto));
const auto& error_str = maybe_error.GetDataAndErrorProto(error_proto->DebugString());
return error_str.first;
}
} // namespace
} // namespace oneflow
#define CHECK_OK(...) \
for (auto&& maybe = __JustStackCheckWrapper__(__VA_ARGS__); \
GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!maybe.IsOk());) \
LOG(FATAL) << OF_PP_STRINGIZE(__VA_ARGS__) << " is not OK:\n" << maybe.GetSerializedError()
#define OF_RETURN_IF_ERROR(...) \
for (auto&& maybe_##__LINE__ = __JustStackCheckWrapper__(__VA_ARGS__); \
!maybe_##__LINE__.IsOk();) \
return Error(maybe_##__LINE__.error()).AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
#define OF_TODO() return Error::TodoError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
#define OF_UNIMPLEMENTED() \
return Error::UnimplementedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
#define OF_RUNTIME_ERROR() \
return Error::RuntimeError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) << "RuntimeError " \
": "
#define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt
#define OF_LOG_ONCE(x) \
{ \
static bool warned = false; \
if (!warned) { \
warned = true; \
x; \
} \
}
#define OF_COMPLIE_OPTION_ERROR() \
return Error::CompileOptionWrongError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
<< "Compile option wrong: "
#define CHECK_OR_RETURN(expr) \
if (!(expr)) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
<< "Check failed: " << OF_PP_STRINGIZE(expr) << " "
#define CHECK_EQ_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_GE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) >= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_GT_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) > (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_LE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) <= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_LT_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) < (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_NE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) != (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_STREQ_OR_RETURN(lhs, rhs) CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs))
#define CHECK_STRNE_OR_RETURN(lhs, rhs) CHECK_NE_OR_RETURN(std::string(lhs), std::string(rhs))
#define CHECK_NOTNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr != nullptr)
#define CHECK_ISNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr == nullptr)
#define TODO_THEN_RETURN() OF_TODO()
#define UNIMPLEMENTED_THEN_RETURN() OF_UNIMPLEMENTED()
#endif // ONEFLOW_CORE_COMMON_MAYBE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "gtest/gtest.h"
#include <gtest/gtest-death-test.h>
#include <memory>
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace test {
TEST(Maybe, JUST_MSG) {
auto f = [](int x) -> Maybe<int> {
if (x > 10) { return Error::InvalidValueError("") << "input value " << x; }
return 233;
};
auto g = [](int x) { return x * x - 5 * x + 3; };
auto h = [&](int x) -> Maybe<int> {
auto y = g(x);
return JUST_MSG(f(y), "input value g(", x, ")");
};
auto i = [&](float x) -> Maybe<int> {
int y = x;
return JUST_MSG(h(y), std::stringstream() << "input value int(" << x << ")");
};
auto data = CHECK_JUST(i(1));
ASSERT_EQ(data, 233);
auto err = i(10.123).error();
ASSERT_EQ(err->msg(), "input value 53");
ASSERT_EQ(err->stack_frame(0).error_msg(), "f(y): input value g(10)");
ASSERT_EQ(err->stack_frame(1).error_msg(), "h(y): input value int(10.123)");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto)
ASSERT_EXIT(CHECK_JUST(i(10.234)), testing::KilledBySignal(SIGABRT), R"(input value 53)");
}
TEST(Maybe, CHECK_OK) {
auto f = [](int x) -> Maybe<int> {
if (x > 10) { return Error::InvalidValueError("") << "input value " << x; }
return 233;
};
auto g = [&](int x) -> Maybe<int> {
auto y = JUST(f(x));
return f(y);
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto)
ASSERT_EXIT(CHECK_OK(g(11)), testing::KilledBySignal(SIGABRT), R"(g\(11\) is not OK)");
}
TEST(Maybe, Noncopyable) { Maybe<std::unique_ptr<int>> a{std::make_unique<int>(1)}; }
} // namespace test
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_META_UTIL_HPP_
#define ONEFLOW_CORE_COMMON_META_UTIL_HPP_
#include <utility>
#include <tuple>
namespace oneflow {
template<typename... Args, typename Func, std::size_t... Idx>
void for_each(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
(void)std::initializer_list<int>{(f(std::get<Idx>(t)), void(), 0)...};
}
template<typename... Args, typename Func, std::size_t... Idx>
void for_each_i(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
(void)std::initializer_list<int>{
(f(std::get<Idx>(t), std::integral_constant<size_t, Idx>{}), void(), 0)...};
}
template<typename T>
using remove_const_reference_t = std::remove_const_t<std::remove_reference_t<T>>;
template<std::size_t... Is>
auto make_tuple_from_sequence(std::index_sequence<Is...>) {
return std::make_tuple(Is...);
}
template<std::size_t N>
constexpr auto make_tuple_from_sequence() {
return make_tuple_from_sequence(std::make_index_sequence<N>{});
}
namespace detail {
template<class Tuple, class F, std::size_t... Is>
void tuple_switch(const std::size_t i, Tuple&& t, F&& f, std::index_sequence<Is...>) {
(void)std::initializer_list<int>{
(i == Is && ((void)std::forward<F>(f)(std::integral_constant<size_t, Is>{}), 0))...};
}
} // namespace detail
template<class Tuple, class F>
inline void tuple_switch(const std::size_t i, Tuple&& t, F&& f) {
constexpr auto N = std::tuple_size<std::remove_reference_t<Tuple>>::value;
detail::tuple_switch(i, std::forward<Tuple>(t), std::forward<F>(f),
std::make_index_sequence<N>{});
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_META_UTIL_HPP_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/nd_index.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
NdIndex::NdIndex(const std::initializer_list<int64_t>& dim_vec) : dim_vec_(dim_vec) {}
NdIndex::NdIndex(const DimVector& dim_vec) : dim_vec_(dim_vec) {}
NdIndex& NdIndex::operator=(const NdIndex& shape) {
dim_vec_ = shape.dim_vec_;
return *this;
}
bool NdIndex::operator==(const NdIndex& rhs) const { return dim_vec_ == rhs.dim_vec_; }
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ND_INDEX_H_
#define ONEFLOW_CORE_COMMON_ND_INDEX_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.h"
namespace oneflow {
class NdIndex final {
public:
NdIndex() = default;
explicit NdIndex(const DimVector& dim_vec);
NdIndex(const std::initializer_list<int64_t>& dim_vec);
~NdIndex() = default;
NdIndex& operator=(const NdIndex& other);
bool operator==(const NdIndex& rhs) const;
bool operator!=(const NdIndex& rhs) const { return !(*this == rhs); }
const DimVector& dim_vec() const { return dim_vec_; }
int64_t At(int64_t index) const { return dim_vec_.at(index); }
int64_t NumAxes() const { return dim_vec_.size(); }
private:
DimVector dim_vec_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ND_INDEX_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_
#define ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_
#include "oneflow/core/common/data_type.h"
#include <cassert>
namespace oneflow {
template<typename T, int N>
class NdIndexOffsetHelper {
public:
NdIndexOffsetHelper() {}
template<class... Ts>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... dims) {
constexpr int n = 1 + sizeof...(dims);
static_assert(n <= N, "");
T dims_arr[n] = {d0, static_cast<T>(dims)...};
InitStrides(dims_arr, n);
}
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims) { InitStrides(dims, N); }
template<typename U>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims) {
T dims_arr[N];
for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }
InitStrides(dims_arr, N);
}
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims, int n) { InitStrides(dims, n); }
template<typename U>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims, int n) {
T dims_arr[N];
for (int i = 0; i < N; ++i) {
if (i < n) { dims_arr[i] = dims[i]; }
}
InitStrides(dims_arr, n);
}
~NdIndexOffsetHelper() = default;
OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N - 1; ++i) { offset += index[i] * stride_[i]; }
offset += index[N - 1];
return offset;
}
OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const {
assert(n <= N);
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i < n) { offset += index[i] * stride_[i]; }
}
return offset;
}
template<class... Ts>
OF_DEVICE_FUNC T NdIndexToOffset(T d0, Ts... others) const {
constexpr int n = 1 + sizeof...(others);
static_assert(n <= N, "");
T index[n] = {d0, others...};
T offset = 0;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < n - 1; ++i) { offset += index[i] * stride_[i]; }
if (n == N) {
offset += index[n - 1];
} else {
offset += index[n - 1] * stride_[n - 1];
}
return offset;
}
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N - 1; ++i) {
const T idx = remaining / stride_[i];
index[i] = idx;
remaining = remaining - idx * stride_[i];
}
index[N - 1] = remaining;
}
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {
assert(n <= N);
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
if (i < n) {
const T idx = remaining / stride_[i];
index[i] = idx;
remaining = remaining - idx * stride_[i];
}
}
}
template<class... Ts>
OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... others) const {
constexpr int n = 1 + sizeof...(others);
static_assert(n <= N, "");
T* index[n] = {&d0, &others...};
T remaining = offset;
#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int i = 0; i < n - 1; ++i) {
const T idx = remaining / stride_[i];
*index[i] = idx;
remaining = remaining - idx * stride_[i];
}
if (n == N) {
*index[n - 1] = remaining;
} else {
*index[n - 1] = remaining / stride_[n - 1];
}
}
OF_DEVICE_FUNC constexpr int Size() const { return N; }
private:
OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) {
for (int i = n - 1; i < N; ++i) { stride_[i] = 1; }
for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; }
}
T stride_[N];
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment