Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
// SEQ
#define BOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define FLOATING_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \
OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define SIGNED_INT_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define UNSIGNED_INT_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define INT_DATA_TYPE_SEQ SIGNED_INT_DATA_TYPE_SEQ
#define CHAR_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define ARITHMETIC_DATA_TYPE_SEQ \
FLOATING_DATA_TYPE_SEQ \
INT_DATA_TYPE_SEQ
#define POD_DATA_TYPE_SEQ \
ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ
#define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ
#define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord)
#define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ
#define INDEX_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#if defined(WITH_CUDA)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif
#if defined(WITH_ROCM)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif
#define IMAGE_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \
OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define NO_BOXING_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) \
OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer)
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DECORATOR_H_
#define ONEFLOW_CORE_COMMON_DECORATOR_H_
#include <type_traits>
#include <unordered_map>
#include "tuple_hash.h"
#include "static_check.h"
#include "oneflow/core/common/env_var/env_var.h"
#include "oneflow/core/common/cpp_attribute.h"
namespace oneflow {
template<template<typename...> class Decorator>
struct WithDecorator final {
template<typename T, typename = void>
struct Decorate;
template<typename T, typename... Args>
struct Decorate<T (*)(Args...)> final {
template<T (*func)(Args...)>
static T Call(Args... args) {
return Decorator<T, Args...>::template Call<func>(args...);
}
};
};
#define DECORATE(fn_ptr, decorator) \
(&WithDecorator<decorator>::Decorate<decltype(fn_ptr)>::Call<fn_ptr>)
template<typename... Args>
struct ThreadLocalCopiable;
template<typename RetT>
struct ThreadLocalCopiable<RetT> {
template<RetT (*func)()>
static RetT Call() {
static thread_local RetT value = func();
return value;
}
};
template<typename RetT, typename Arg0>
struct ThreadLocalCopiable<RetT, Arg0> {
template<RetT (*func)(Arg0)>
static RetT Call(Arg0 arg0) {
using KeyT = typename std::decay<Arg0>::type;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<KeyT, MappedT> map;
auto iter = map.find(arg0);
if (iter == map.end()) { iter = map.emplace(arg0, func(arg0)).first; }
return iter->second;
}
private:
static_assert(!IsOutArg<Arg0>::value, "");
static_assert(!StaticAny<IsOutArg, Arg0>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1>
struct ThreadLocalCopiable<RetT, Arg0, Arg1> {
template<RetT (*func)(Arg0, Arg1)>
static RetT Call(Arg0 arg0, Arg1 arg1) {
using KeyT0 = typename std::decay<Arg0>::type;
using KeyT1 = typename std::decay<Arg1>::type;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<KeyT0, std::unordered_map<KeyT1, MappedT>> map;
auto* last_map = &map[arg0];
auto iter = last_map->find(arg1);
if (iter == last_map->end()) { iter = last_map->emplace(arg1, func(arg0, arg1)).first; }
return iter->second;
}
private:
static_assert(!StaticAny<IsOutArg, Arg0, Arg1>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1, typename Arg2>
struct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2> {
template<RetT (*func)(Arg0, Arg1, Arg2)>
static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2) {
using KeyT0 = typename std::decay<Arg0>::type;
using KeyT1 = typename std::decay<Arg1>::type;
using KeyT2 = typename std::decay<Arg2>::type;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<
KeyT0, std::unordered_map<KeyT1, std::unordered_map<KeyT2, MappedT>>>
map;
auto* last_map = &map[arg0][arg1];
auto iter = last_map->find(arg2);
if (iter == last_map->end()) { iter = last_map->emplace(arg2, func(arg0, arg1, arg2)).first; }
return iter->second;
}
private:
static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1, typename Arg2, typename Arg3,
typename... Args>
struct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2, Arg3, Args...> {
template<RetT (*func)(Arg0, Arg1, Arg2, Arg3, Args...)>
static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2, Arg3 arg3, Args... args) {
using KeyT0 = typename std::decay<Arg0>::type;
using KeyT1 = typename std::decay<Arg1>::type;
using KeyT2 = typename std::decay<Arg2>::type;
using KeyT3 = typename std::decay<Arg3>::type;
using KeyT = std::tuple<KeyT0, KeyT1, KeyT2, KeyT3, typename std::decay<Args>::type...>;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<KeyT, MappedT> map;
const auto& key = KeyT(arg0, arg1, arg2, arg3, args...);
auto iter = map.find(key);
if (iter == map.end()) { iter = map.emplace(key, func(arg0, arg1, arg2, arg3, args...)).first; }
return iter->second;
}
private:
static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2, Arg3, Args...>::value, "");
};
// for scalar type key.
template<typename RetT, typename... Args>
struct ThreadLocal : public ThreadLocalCopiable<RetT, Args...> {
private:
static_assert(StaticAll<IsDecayedScalarType, Args...>::value, "");
};
template<typename... Args>
struct ThreadLocalCachedCopiable;
template<typename RetT>
struct ThreadLocalCachedCopiable<RetT> {
template<RetT (*func)()>
static RetT Call() {
static thread_local RetT value = func();
return value;
}
};
template<typename RetT, typename Arg0>
struct ThreadLocalCachedCopiable<RetT, Arg0> {
template<RetT (*func)(Arg0)>
static RetT Call(Arg0 arg0) {
using KeyT = typename std::decay<Arg0>::type;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<KeyT, MappedT> map;
auto iter = map.find(arg0);
if (iter == map.end()) {
if (unlikely(map.size() >= ThreadLocalEnvInteger<ONEFLOW_THRAED_LOCAL_CACHED_SIZE>())) {
map.clear();
}
iter = map.emplace(arg0, func(arg0)).first;
}
return iter->second;
}
private:
static_assert(!IsOutArg<Arg0>::value, "");
static_assert(!StaticAny<IsOutArg, Arg0>::value, "");
};
template<typename RetT, typename Arg0, typename... Args>
struct ThreadLocalCachedCopiable<RetT, Arg0, Args...> {
template<RetT (*func)(Arg0, Args...)>
static RetT Call(Arg0 arg0, Args... args) {
using KeyT0 = typename std::decay<Arg0>::type;
using KeyT = std::tuple<KeyT0, typename std::decay<Args>::type...>;
using MappedT = typename std::decay<RetT>::type;
static thread_local std::unordered_map<KeyT, MappedT> map;
const auto& key = KeyT(arg0, args...);
auto iter = map.find(key);
if (iter == map.end()) {
if (unlikely(map.size() >= ThreadLocalEnvInteger<ONEFLOW_THRAED_LOCAL_CACHED_SIZE>())) {
map.clear();
}
iter = map.emplace(key, func(arg0, args...)).first;
}
return iter->second;
}
private:
static_assert(!StaticAny<IsOutArg, Arg0, Args...>::value, "");
};
// for scalar type key.
template<typename RetT, typename... Args>
struct ThreadLocalCached : public ThreadLocalCachedCopiable<RetT, Args...> {
private:
static_assert(StaticAll<IsDecayedScalarType, Args...>::value, "");
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DECORATOR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace test {
Maybe<int> Inc(int x) { return x + 1; }
Maybe<int> IncByConstRef(const int& x) { return x + 1; }
TEST(ThreadLocal, scalar) {
auto* CachedInc = DECORATE(&Inc, ThreadLocal);
int x = CHECK_JUST(CachedInc(0));
ASSERT_EQ(x, 1);
}
TEST(ThreadLocal, const_ref) {
auto* CachedIncByConstRef = DECORATE(&IncByConstRef, ThreadLocal);
int x = CHECK_JUST(CachedIncByConstRef(0));
ASSERT_EQ(x, 1);
}
namespace {
struct Foo {
static Maybe<Foo> New(int x) { return std::shared_ptr<Foo>(new Foo{x}); }
int x;
};
} // namespace
TEST(ThreadLocal, _class) {
auto* CachedFooNew = DECORATE(&Foo::New, ThreadLocal);
const auto& foo = CHECK_JUST(CachedFooNew(10));
const auto& bar = CHECK_JUST(CachedFooNew(10));
ASSERT_EQ(foo->x, 10);
ASSERT_TRUE(foo == bar);
}
} // namespace test
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
#define ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
#include "oneflow/core/common/device_type.pb.h"
namespace std {
template<>
struct hash<oneflow::DeviceType> final {
size_t operator()(oneflow::DeviceType device_type) const {
return static_cast<size_t>(device_type);
}
};
} // namespace std
namespace oneflow {
inline std::string PrintAvailableDevices() {
std::string str("cpu");
#if defined(WITH_CUDA) || defined(WITH_ROCM)
str += ", cuda";
#endif
return str;
}
inline std::string PrintGeneratorAvailableDevices() {
std::string str("cpu");
#if defined(WITH_CUDA) || defined(WITH_ROCM)
str += ", cuda";
#endif
str += ", auto"; // "auto" is a fake device type for random generator.
return str;
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#define DEVICE_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA)
#else
#define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)
#endif
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
syntax = "proto2";
package oneflow;
enum DeviceType {
kInvalidDevice = 0;
kCPU = 1;
kCUDA = 2;
kMockDevice = 3; // pseudo device for test.
kROCm = 4;
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
#define ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
#include "oneflow/core/common/dtype_signature.pb.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
inline bool operator==(const DTypeSignature& lhs, const DTypeSignature& rhs) {
return PbMd().Equals(lhs, rhs);
}
} // namespace oneflow
namespace std {
template<>
struct hash<oneflow::DTypeSignature> final {
size_t operator()(const oneflow::DTypeSignature& dtype_signature) {
std::string serialized;
dtype_signature.SerializeToString(&serialized);
return std::hash<std::string>()(serialized);
}
};
} // namespace std
#endif // ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
syntax = "proto2";
package oneflow;
import "oneflow/core/common/data_type.proto";
message DTypeSignature {
map<string, DataType> name2dtype = 1;
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
#define ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
#include "Eigen/Core"
#include "Eigen/Dense"
namespace oneflow {
template<typename T>
using EigenMatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
template<typename T>
using EigenArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template<typename T>
using ConstEigenMatrixMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
template<typename T>
using ConstEigenArrayMap = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EITHER_PTR_H_
#define ONEFLOW_CORE_COMMON_EITHER_PTR_H_
#include <glog/logging.h>
#include <memory>
namespace oneflow {
template<typename X, typename Y>
class EitherPtr final {
public:
static_assert(!std::is_same<X, Y>::value, "X should not be Y");
using XPtr = std::shared_ptr<X>;
using YPtr = std::shared_ptr<Y>;
// WARNING: we should assume that the structure of shared_ptr<X> and shared_ptr<Y> is same,
// and obviously at most time the assumption holds
static_assert(sizeof(XPtr) == sizeof(YPtr), "unsupported shared_ptr implementation");
EitherPtr() : type_(UnionType<X>::value), x_ptr_(nullptr) {}
EitherPtr(const XPtr& ptr) : type_(UnionType<X>::value), x_ptr_(ptr) {}
EitherPtr(const YPtr& ptr) : type_(UnionType<Y>::value) { new (&x_ptr_) YPtr(ptr); }
EitherPtr(XPtr&& ptr) : type_(UnionType<X>::value), x_ptr_(std::move(ptr)) {}
EitherPtr(YPtr&& ptr) : type_(UnionType<Y>::value) { new (&x_ptr_) YPtr(std::move(ptr)); }
EitherPtr(const EitherPtr& either_ptr) : type_(either_ptr.type_), x_ptr_(either_ptr.x_ptr_) {}
EitherPtr(EitherPtr&& either_ptr)
: type_(either_ptr.type_), x_ptr_(std::move(either_ptr.x_ptr_)) {}
// the destructor of X or Y will be called properly because it will be stored in the deleter of
// shared_ptr while constructed
~EitherPtr() = default;
EitherPtr& operator=(const EitherPtr& either_ptr) {
x_ptr_ = either_ptr.x_ptr_;
type_ = either_ptr.type_;
return *this;
}
EitherPtr& operator=(EitherPtr&& either_ptr) {
x_ptr_ = std::move(either_ptr.x_ptr_);
type_ = either_ptr.type_;
return *this;
}
template<typename T>
bool Has() const {
return type_ == UnionType<T>::value;
}
template<typename T>
const std::shared_ptr<T>& Get() const {
return Get(tag<T>{});
}
private:
template<typename T, typename Enable = void>
struct UnionType;
template<typename T>
struct UnionType<T, typename std::enable_if<std::is_same<X, T>::value>::type> {
static constexpr int8_t value = 0;
};
template<typename T>
struct UnionType<T, typename std::enable_if<std::is_same<Y, T>::value>::type> {
static constexpr int8_t value = 1;
};
template<typename>
struct tag {};
const XPtr& Get(tag<X>) const {
CHECK(Has<X>());
return x_ptr_;
}
const YPtr& Get(tag<Y>) const {
CHECK(Has<Y>());
const auto* __attribute__((__may_alias__)) ptr = reinterpret_cast<const YPtr*>(&x_ptr_);
return *ptr;
}
int8_t type_;
std::shared_ptr<X> x_ptr_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EITHER_PTR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
DEFINE_ENV_BOOL(ONEFLOW_DEBUG_MODE, false);
DEFINE_ENV_BOOL(ONEFLOW_DEBUG, false);
inline bool IsInDebugMode() { return EnvBool<ONEFLOW_DEBUG_MODE>() || EnvBool<ONEFLOW_DEBUG>(); }
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
template<typename env_var>
bool EnvBool();
#define DEFINE_ENV_BOOL(env_var, default_value) \
struct env_var {}; \
template<> \
inline bool EnvBool<env_var>() { \
return ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
}
template<typename env_var>
int64_t EnvInteger();
#define DEFINE_ENV_INTEGER(env_var, default_value) \
struct env_var {}; \
template<> \
inline int64_t EnvInteger<env_var>() { \
return ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
}
DEFINE_ENV_INTEGER(ONEFLOW_TIMEOUT_SECONDS, 7200);
DEFINE_ENV_INTEGER(ONEFLOW_CHECK_TIMEOUT_SLEEP_SECONDS, EnvInteger<ONEFLOW_TIMEOUT_SECONDS>());
DEFINE_ENV_INTEGER(ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT, 100);
DEFINE_ENV_INTEGER(ONEFLOW_DELETE_OUTDATED_SHM_NAMES_INTERVAL, 1000);
template<typename env_var>
bool ThreadLocalEnvBool();
#define DEFINE_THREAD_LOCAL_ENV_BOOL(env_var, default_value) \
struct env_var {}; \
template<> \
inline bool ThreadLocalEnvBool<env_var>() { \
thread_local bool value = ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
return value; \
}
template<typename env_var>
int64_t ThreadLocalEnvInteger();
#define DEFINE_THREAD_LOCAL_ENV_INTEGER(env_var, default_value) \
struct env_var {}; \
template<> \
inline int64_t ThreadLocalEnvInteger<env_var>() { \
thread_local int64_t value = ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
return value; \
}
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_THRAED_LOCAL_CACHED_SIZE, 128 * 1024);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_WORKLOAD_ON_SCHEDULER_THREAD, false);
}
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <stdexcept>
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/exception.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/env_var/debug_mode.h"
namespace oneflow {
namespace {
void LogError(const Error& error) {
// gdb break point
LOG(ERROR) << error->msg();
}
std::shared_ptr<ErrorProto>* MutThreadLocalError() {
thread_local std::shared_ptr<ErrorProto> error;
return &error;
}
} // namespace
Error&& Error::AddStackFrame(const std::string& file, const int64_t& line,
const std::string& function) {
auto* stack_frame = error_proto_->add_stack_frame();
stack_frame->set_file(file);
stack_frame->set_line(line);
stack_frame->set_function(function);
return std::move(*this);
}
void Error::Merge(const Error& other) {
std::string error_summary{error_proto_->error_summary()};
std::string msg{error_proto_->msg()};
error_proto_->MergeFrom(*other.error_proto_);
// MergeFrom will overwrite singular field, so restore it.
if (!error_summary.empty()) {
error_proto_->set_error_summary(error_summary + " " + error_proto_->error_summary());
}
if (!msg.empty()) { error_proto_->set_msg(msg + " " + error_proto_->msg()); }
}
Error::operator std::string() const { return error_proto_->DebugString(); }
Error Error::Ok() { return std::make_shared<ErrorProto>(); }
Error Error::ProtoParseFailedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_proto_parse_failed_error();
return error;
}
Error Error::JobSetEmptyError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_set_empty_error();
return error;
}
Error Error::DeviceTagNotFoundError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_device_tag_not_found_error();
return error;
}
Error Error::InvalidValueError(const std::string& error_summary) {
auto error = std::make_shared<ErrorProto>();
error->set_error_summary(error_summary);
error->mutable_invalid_value_error();
return error;
}
Error Error::IndexError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_index_error();
return error;
}
Error Error::TypeError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_type_error();
return error;
}
Error Error::TimeoutError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_timeout_error();
return error;
}
Error Error::JobNameExistError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_name_exist_error();
return error;
}
Error Error::JobNameEmptyError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_name_empty_error();
return error;
}
Error Error::JobNameNotEqualError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_name_not_equal_error();
return error;
}
Error Error::NoJobBuildAndInferCtxError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_no_job_build_and_infer_ctx_error();
return error;
}
Error Error::JobConfFrozenError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_conf_frozen_error();
return error;
}
Error Error::JobConfNotSetError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_conf_not_set_error();
return error;
}
Error Error::JobConfRepeatedSetError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_conf_repeated_set_error();
return error;
}
Error Error::JobTypeNotSetError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_job_type_not_set_error();
return error;
}
Error Error::LogicalBlobNameNotExistError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_logical_blob_name_not_exist_error();
return error;
}
Error Error::LogicalBlobNameExistError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_logical_blob_name_exist_error();
return error;
}
Error Error::LogicalBlobNameInvalidError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_logical_blob_name_invalid_error();
return error;
}
Error Error::OpNameExistError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_op_name_exist_error();
return error;
}
Error Error::OpConfDeviceTagNoSetError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_op_conf_device_tag_no_set_error();
return error;
}
Error Error::PlacementError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_placement_error();
return error;
}
Error Error::BlobSplitAxisInferError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_blob_split_axis_infer_error();
return error;
}
Error Error::UnknownJobBuildAndInferError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_unknown_job_build_and_infer_error();
return error;
}
Error Error::CheckFailedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_check_failed_error();
return error;
}
Error Error::ValueNotFoundError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_value_not_found_error();
return error;
}
Error Error::TodoError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_todo_error();
return error;
}
Error Error::UnimplementedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_unimplemented_error();
return error;
}
Error Error::RuntimeError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_runtime_error();
return error;
}
Error Error::OutOfMemoryError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_out_of_memory_error();
return error;
}
Error Error::BoxingNotSupportedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_boxing_not_supported_error();
return error;
}
Error Error::OpKernelNotFoundError(const std::string& error_summary,
const std::vector<std::string>& error_msgs) {
auto error = std::make_shared<ErrorProto>();
error->set_error_summary(error_summary);
auto* op_kernel_not_found_error = error->mutable_op_kernel_not_found_error();
for (const auto& msg : error_msgs) {
op_kernel_not_found_error->add_op_kernels_not_found_debug_str(msg);
}
return error;
}
Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary,
const std::vector<std::string>& error_msgs) {
auto error = std::make_shared<ErrorProto>();
error->set_error_summary(error_summary);
auto* multiple_op_kernels_matched_error = error->mutable_multiple_op_kernels_matched_error();
for (const auto& msg : error_msgs) {
multiple_op_kernels_matched_error->add_matched_op_kernels_debug_str(msg);
}
return error;
}
Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,
uint64_t available, const std::string& device_tag) {
auto error = std::make_shared<ErrorProto>();
auto* memory_zone_out_of_memory_error = error->mutable_memory_zone_out_of_memory_error();
memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id));
memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id));
memory_zone_out_of_memory_error->add_device_tag(device_tag);
memory_zone_out_of_memory_error->add_available(std::to_string(available) + " bytes");
memory_zone_out_of_memory_error->add_required(std::to_string(calc) + " bytes");
return error;
}
Error Error::LossBlobNotFoundError(const std::string& error_summary) {
auto error = std::make_shared<ErrorProto>();
error->mutable_loss_blob_not_found_error();
error->set_error_summary(error_summary);
return error;
}
Error Error::RwMutexedObjectNotFoundError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_rw_mutexed_object_not_found_error();
return error;
}
Error Error::GradientFunctionNotFoundError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_gradient_function_not_found_error();
return error;
}
Error Error::SymbolIdUninitializedError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_symbol_id_uninitialized_error();
return error;
}
Error Error::CompileOptionWrongError() {
auto error = std::make_shared<ErrorProto>();
error->mutable_compile_option_wrong_error();
return error;
}
Error Error::InputDeviceNotMatchError() {
auto error = std::make_shared<ErrorProto>();
auto* input_device_not_match_error = error->mutable_input_device_not_match_error();
input_device_not_match_error->add_info(
std::string("Input tensors are at different devices, please try to use tensor.to or "
"module.to to correct it."));
return error;
}
std::string GetStackedErrorString(const std::shared_ptr<ErrorProto>& error) {
const auto& maybe_error = TRY(FormatErrorStr(error));
const auto& error_str = maybe_error.GetDataAndErrorProto(error->DebugString());
CHECK_NE(error->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
return error_str.first;
}
std::string GetErrorString(const std::shared_ptr<ErrorProto>& error) {
if (IsInDebugMode()) {
return GetStackedErrorString(error);
} else {
if (error->msg().empty() && error->stack_frame().size() > 0) {
return error->stack_frame(0).error_msg();
} else {
return error->msg();
}
}
}
void ThrowError(const std::shared_ptr<ErrorProto>& error) {
*MutThreadLocalError() = error;
if (error->has_runtime_error()) { throw RuntimeException(GetErrorString(error)); }
if (error->has_type_error()) { throw TypeException(GetErrorString(error)); }
if (error->has_index_error()) { throw IndexException(GetErrorString(error)); }
if (error->has_unimplemented_error()) { throw NotImplementedException(GetErrorString(error)); }
throw Exception(GetStackedErrorString(error));
}
const std::shared_ptr<ErrorProto>& ThreadLocalError() { return *MutThreadLocalError(); }
const char* kOfBugIssueUploadPrompt =
"This is a oneflow bug, please submit issues in "
"'https://github.com/Oneflow-Inc/oneflow/issues' include the log information of the error, the "
"minimum reproduction code, and the system information.";
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ERROR_H_
#define ONEFLOW_CORE_COMMON_ERROR_H_
#include <sstream>
#include <vector>
#include "oneflow/core/common/error.pb.h"
namespace oneflow {
class Error final {
public:
Error(const std::shared_ptr<ErrorProto>& error_proto) : error_proto_(error_proto) {}
Error(const Error&) = default;
~Error() = default;
std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; }
const ErrorProto* operator->() const { return error_proto_.get(); }
ErrorProto* operator->() { return error_proto_.get(); }
operator std::string() const;
void Assign(const Error& other) { error_proto_ = other.error_proto_; }
void Merge(const Error& other);
// r-value reference is used to supporting expressions like `Error().AddStackFrame("foo.cpp",
// ,"line", "Bar") << "invalid value"` because operator<<() need r-value reference
Error&& AddStackFrame(const std::string& file, const int64_t& line, const std::string& function);
static Error Ok();
static Error ProtoParseFailedError();
static Error JobSetEmptyError();
static Error DeviceTagNotFoundError();
static Error InvalidValueError(const std::string& error_summary);
static Error IndexError();
static Error TypeError();
static Error TimeoutError();
static Error JobNameExistError();
static Error JobNameEmptyError();
static Error JobNameNotEqualError();
static Error NoJobBuildAndInferCtxError();
static Error JobConfFrozenError();
static Error JobConfNotSetError();
static Error JobConfRepeatedSetError();
static Error JobTypeNotSetError();
static Error LogicalBlobNameNotExistError();
static Error LogicalBlobNameExistError();
static Error LogicalBlobNameInvalidError();
static Error OpNameExistError();
static Error OpConfDeviceTagNoSetError();
static Error PlacementError();
static Error BlobSplitAxisInferError();
static Error UnknownJobBuildAndInferError();
static Error CheckFailedError();
static Error ValueNotFoundError();
static Error TodoError();
static Error UnimplementedError();
static Error RuntimeError();
static Error OutOfMemoryError();
static Error BoxingNotSupportedError();
static Error MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,
uint64_t available, const std::string& device_type);
static Error OpKernelNotFoundError(const std::string& error_summary,
const std::vector<std::string>& error_msgs);
static Error MultipleOpKernelsMatchedError(const std::string& error_summary,
const std::vector<std::string>& error_msgs);
static Error LossBlobNotFoundError(const std::string& error_summary);
static Error RwMutexedObjectNotFoundError();
// gradient
static Error GradientFunctionNotFoundError();
// symbol
static Error SymbolIdUninitializedError();
static Error CompileOptionWrongError();
static Error InputDeviceNotMatchError();
private:
std::shared_ptr<ErrorProto> error_proto_;
};
void ThrowError(const std::shared_ptr<ErrorProto>& error);
const std::shared_ptr<ErrorProto>& ThreadLocalError();
template<typename T>
Error& operator<<(Error& error, const T& x) {
std::ostringstream ss;
ss << x;
if (error->stack_frame().empty()) {
error->set_msg(error->msg() + ss.str());
} else {
auto* stack_frame_top = error->mutable_stack_frame(error->stack_frame_size() - 1);
stack_frame_top->set_error_msg(stack_frame_top->error_msg() + ss.str());
}
return error;
}
// r-value reference is used to supporting expressions like `Error() << "invalid value"`
template<typename T>
Error&& operator<<(Error&& error, const T& x) {
error << x;
return std::move(error);
}
template<>
inline Error&& operator<<(Error&& error, const std::stringstream& x) {
error << x.str();
return std::move(error);
}
template<>
inline Error&& operator<<(Error&& error, const std::ostream& x) {
error << x.rdbuf();
return std::move(error);
}
template<>
inline Error&& operator<<(Error&& error, const Error& other) {
error.Merge(other);
return std::move(error);
}
extern const char* kOfBugIssueUploadPrompt;
} // namespace oneflow
#define PRINT_BUG_PROMPT_AND_ABORT() LOG(FATAL) << kOfBugIssueUploadPrompt
#endif // ONEFLOW_CORE_COMMON_ERROR_H_
syntax = "proto2";
package oneflow;
message FieldValue {
required string field = 1;
required string value = 2;
}
enum OpcodeType {
kInvalidCompareType = 0;
kEq = 1;
kNe = 2;
kGt = 3;
kGe = 4;
kLt = 5;
kLe = 6;
}
message OneFieldAssertError {
required OpcodeType compare_type = 1;
required FieldValue left = 2;
required string right_value = 3;
}
message TwoFieldAssertError {
required OpcodeType compare_type = 1;
required FieldValue left = 2;
required FieldValue right = 3;
}
message ConfigAssertFailedError {
oneof oprand_type {
OneFieldAssertError one_field_assert_error = 1;
TwoFieldAssertError two_field_assert_error = 2;
}
}
message ConfigResourceUnavailableError {
required FieldValue field_value = 1;
}
message JobSetEmptyError { }
message DeviceTagNotFoundError { }
message JobNameExistError { }
message JobNameEmptyError { }
message JobNameNotEqualError { }
message NoJobBuildAndInferCtxError { }
message JobConfFrozenError { }
message JobConfNotSetError { }
message JobConfRepeatedSetError { }
message JobTypeNotSetError { }
message LogicalBlobNameNotExistError { }
message LogicalBlobNameExistError { }
message LogicalBlobNameInvalidError { }
message OpNameExistError { }
message OpConfDeviceTagNoSetError { }
message PlacementError { }
message BlobSplitAxisInferError { }
message UnknownJobBuildAndInferError { }
message ProtoParseFailedError { }
message CheckFailedError { }
message TodoError { }
message UnimplementedError { }
message RuntimeError { }
message OutOfMemoryError { }
message BoxingNotSupportedError { }
message GradientFunctionNotFoundError { }
message OpKernelNotFoundError {
repeated string op_kernels_not_found_debug_str = 1;
}
message MultipleOpKernelsMatchedError {
repeated string matched_op_kernels_debug_str = 1;
}
message MemoryZoneOutOfMemoryError {
repeated string machine_id = 1;
repeated string mem_zone_id = 2;
repeated string device_tag = 3;
repeated string required = 4;
repeated string available = 5;
}
message LossBlobNotFoundError { }
message RwMutexedObjectNotFoundError { }
message UnknownError { }
message CompileOptionWrongError { }
message InputDeviceNotMatchError {
repeated string info = 1;
}
message ErrorStackFrame {
required string file = 1;
required int64 line = 2;
required string function = 3;
required string error_msg = 4;
}
message SymbolIdUninitializedError {}
message InvalidValueError {}
message IndexError {}
message TypeError {}
message TimeoutError {}
message ValueNotFoundError {}
message ErrorProto {
optional string error_summary = 1 [default = ""];
optional string msg = 2 [default = ""];
repeated ErrorStackFrame stack_frame = 3;
oneof error_type {
ConfigAssertFailedError config_assert_failed_error = 12;
ConfigResourceUnavailableError config_resource_unavailable_error = 13;
ProtoParseFailedError proto_parse_failed_error = 15;
CheckFailedError check_failed_error = 16;
TodoError todo_error = 17;
UnimplementedError unimplemented_error = 18;
BoxingNotSupportedError boxing_not_supported_error = 19;
GradientFunctionNotFoundError gradient_function_not_found_error = 20;
OpKernelNotFoundError op_kernel_not_found_error = 21;
MultipleOpKernelsMatchedError multiple_op_kernels_matched_error = 22;
MemoryZoneOutOfMemoryError memory_zone_out_of_memory_error = 23;
LossBlobNotFoundError loss_blob_not_found_error = 24;
JobSetEmptyError job_set_empty_error = 25;
DeviceTagNotFoundError device_tag_not_found_error = 26;
InvalidValueError invalid_value_error = 27;
IndexError index_error = 28;
TypeError type_error = 29;
RuntimeError runtime_error = 30;
OutOfMemoryError out_of_memory_error = 32;
TimeoutError timeout_error = 40;
ValueNotFoundError value_not_found_error = 31;
JobNameExistError job_name_exist_error = 100;
JobNameEmptyError job_name_empty_error = 101;
JobNameNotEqualError job_name_not_equal_error = 102;
NoJobBuildAndInferCtxError no_job_build_and_infer_ctx_error = 200;
JobConfFrozenError job_conf_frozen_error = 300;
JobConfNotSetError job_conf_not_set_error = 301;
JobConfRepeatedSetError job_conf_repeated_set_error = 302;
JobTypeNotSetError job_type_not_set_error = 303;
LogicalBlobNameNotExistError logical_blob_name_not_exist_error = 400;
LogicalBlobNameExistError logical_blob_name_exist_error = 401;
LogicalBlobNameInvalidError logical_blob_name_invalid_error = 402;
OpNameExistError op_name_exist_error = 450;
OpConfDeviceTagNoSetError op_conf_device_tag_no_set_error = 460;
PlacementError placement_error= 470;
BlobSplitAxisInferError blob_split_axis_infer_error = 480;
UnknownJobBuildAndInferError unknown_job_build_and_infer_error = 500;
RwMutexedObjectNotFoundError rw_mutexed_object_not_found_error = 600;
SymbolIdUninitializedError symbol_id_uninitialized_error = 700;
UnknownError unknown_error = 900;
CompileOptionWrongError compile_option_wrong_error = 950;
InputDeviceNotMatchError input_device_not_match_error = 1000;
}
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <sstream>
#include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/graph_scope_vars.h"
namespace oneflow {
namespace {
std::string StripSpace(std::string str) {
if (str.size() == 0) { return ""; }
size_t pos = str.find_first_not_of(" ");
if (pos != std::string::npos) { str.erase(0, pos); }
pos = str.find_last_not_of(" ");
if (pos != std::string::npos) { str.erase(pos + 1); }
return str;
}
bool IsLetterNumberOrUnderline(char c) {
return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_');
}
Maybe<std::string> ShortenMsg(std::string str) {
// 150 characters is the threshold
const int num_character_threshold = 150;
const int num_displayed_character = 50;
if (str.size() == 0) { return str; }
// strip space when JUST( xx );
str = StripSpace(str);
if (str.size() < num_character_threshold) { return str; }
// left part whose number of characters is just over 50
int left_index = num_displayed_character;
bool pre_condition = IsLetterNumberOrUnderline(str.at(left_index));
for (; left_index < str.size(); left_index++) {
bool cur_condition = IsLetterNumberOrUnderline(str.at(left_index));
if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { break; }
}
// right part whose number of characters is just over 50
int right_index = str.size() - num_displayed_character;
pre_condition = IsLetterNumberOrUnderline(str.at(right_index));
for (; right_index >= 0; right_index--) {
bool cur_condition = IsLetterNumberOrUnderline(str.at(right_index));
if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) {
right_index++;
break;
}
}
// a long word of more than 150
if (right_index - left_index < 50) { return str; }
std::stringstream ss;
CHECK_OR_RETURN(left_index >= 0);
CHECK_OR_RETURN(left_index < str.size());
ss << str.substr(0, left_index);
ss << " ... ";
CHECK_OR_RETURN(right_index >= 0);
CHECK_OR_RETURN(right_index < str.size());
ss << str.substr(right_index);
return ss.str();
}
// file info in stack frame
std::string FormatFileOfStackFrame(const std::string& file) {
std::stringstream ss;
ss << "\n File \"" << file << "\", ";
return ss.str();
}
// line info in stack frame
std::string FormatLineOfStackFrame(const int64_t& line) {
std::stringstream ss;
ss << "line " << line << ",";
return ss.str();
}
// function info in stack frame
std::string FormatFunctionOfStackFrame(const std::string& function) {
std::stringstream ss;
ss << " in " << function;
return ss.str();
}
// msg in stack frame
Maybe<std::string> FormatMsgOfStackFrame(std::string error_msg, bool is_last_stack_frame) {
const bool debug_mode = GetGraphDebugMode();
// only shorten the message if it is not the last stack frame AND not in debug mode
if (!is_last_stack_frame && !debug_mode) { error_msg = *JUST(ShortenMsg(error_msg)); }
// error_msg of last stack frame come from "<<"
if (is_last_stack_frame) { error_msg = StripSpace(error_msg); }
std::stringstream ss;
ss << "\n " << error_msg;
return ss.str();
}
// the error_summary and msg in error proto
std::string FormatErrorSummaryAndMsgOfErrorProto(const std::shared_ptr<ErrorProto>& error) {
std::stringstream ss;
if (error->has_error_summary()) { ss << error->error_summary(); }
if (error->has_msg()) { ss << (ss.str().size() != 0 ? "\n" + error->msg() : error->msg()); }
return ss.str();
}
// the msg in error type instance.
Maybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<ErrorProto>& error) {
CHECK_NE_OR_RETURN(error->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET)
<< Error::RuntimeError() << "Parse error failed, unknown error type";
std::stringstream ss;
const google::protobuf::Descriptor* error_des = error->GetDescriptor();
const google::protobuf::OneofDescriptor* oneof_field_des =
error_des->FindOneofByName("error_type");
const google::protobuf::Reflection* error_ref = error->GetReflection();
const google::protobuf::FieldDescriptor* field_des =
error_ref->GetOneofFieldDescriptor(*error, oneof_field_des);
CHECK_OR_RETURN(field_des != nullptr);
ss << "Error Type: " << field_des->full_name();
return ss.str();
}
} // namespace
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>& error) {
std::stringstream ss;
// Get msg from stack frame of error proto
for (auto stack_frame = error->mutable_stack_frame()->rbegin();
stack_frame < error->mutable_stack_frame()->rend(); stack_frame++) {
ss << FormatFileOfStackFrame(stack_frame->file()) << FormatLineOfStackFrame(stack_frame->line())
<< FormatFunctionOfStackFrame(stack_frame->function())
<< *JUST(FormatMsgOfStackFrame(stack_frame->error_msg(),
stack_frame == error->mutable_stack_frame()->rend() - 1));
}
// Get msg from error summary and msg of error proto
std::string error_summary_and_msg_of_error_proto = FormatErrorSummaryAndMsgOfErrorProto(error);
if (error_summary_and_msg_of_error_proto.size() != 0) {
ss << "\n" << error_summary_and_msg_of_error_proto;
}
// Get msg from error type of error proto
std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error));
if (msg_of_error_type.size() != 0) { ss << "\n" << msg_of_error_type; }
return ss.str();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ERROR_UTIL_H
#define ONEFLOW_CORE_COMMON_ERROR_UTIL_H
#include <string>
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/maybe.h"
namespace oneflow {
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>& error);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ERROR_UTIL_H
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EXCEPTION_H_
#define ONEFLOW_CORE_COMMON_EXCEPTION_H_
#include <exception>
#include <string>
namespace oneflow {
class Exception : public std::exception {
public:
explicit Exception(const std::string& what) : what_(what) {}
virtual ~Exception() = default;
const char* what() const noexcept override { return what_.c_str(); }
private:
std::string what_;
};
class RuntimeException : public Exception {
public:
using Exception::Exception;
};
class TypeException : public Exception {
public:
using Exception::Exception;
};
class IndexException : public Exception {
public:
using Exception::Exception;
};
class NotImplementedException : public Exception {
public:
using Exception::Exception;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EXCEPTION_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/shape.h"
namespace oneflow {
/*static*/ Maybe<FlatShape> FlatShape::New(const Shape& shape) {
const auto& flat_shape = std::make_shared<FlatShape>();
JUST(flat_shape->Init(shape));
return flat_shape;
}
Maybe<void> FlatShape::Init(const Shape& shape) {
CHECK_LE_OR_RETURN(shape.NumAxes(), SHAPE_MAX_AXIS_SIZE);
this->clear_dim();
for (int i = 0; i < shape.NumAxes(); ++i) { *this->mutable_dim()->Add() = shape.At(i); }
return Maybe<void>::Ok();
}
Maybe<void> FlatShape::Check(const Shape& shape) const {
CHECK_EQ_OR_RETURN(this->dim_size(), shape.NumAxes())
<< Error::RuntimeError()
<< "Expected same shape on each rank, but found at least two shapes, "
<< JUST(ToShape())->ToString() << " and " << shape.ToString() << "!";
for (int i = 0; i < this->dim_size(); ++i) { CHECK_EQ_OR_RETURN(this->dim(i), shape.At(i)); }
return Maybe<void>::Ok();
}
Maybe<void> FlatShape::Check(const FlatShape& flat_shape) const {
CHECK_EQ_OR_RETURN(this->dim_size(), flat_shape.NumAxes())
<< Error::RuntimeError()
<< "Expected input of each rank must have the same size, but got at least two size, "
<< JUST(ToShape())->ToString() << " and " << JUST(flat_shape.ToShape())->ToString();
for (int i = 0; i < this->dim_size(); ++i) {
CHECK_EQ_OR_RETURN(this->dim(i), flat_shape.At(i))
<< Error::RuntimeError()
<< "Expected input of each rank must have the same size, but got at least two size, "
<< JUST(ToShape())->ToString() << " and " << JUST(flat_shape.ToShape())->ToString();
}
return Maybe<void>::Ok();
}
Maybe<Shape> FlatShape::ToShape() const {
const auto& shape = std::make_shared<Shape>();
JUST(ToShape(shape.get()));
return shape;
}
Maybe<void> FlatShape::ToShape(Shape* shape) const {
DimVector dim_vec;
for (int i = 0; i < this->dim_size(); ++i) { dim_vec.emplace_back(this->dim(i)); }
*shape = Shape(dim_vec);
return Maybe<void>::Ok();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
#define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
#include <memory>
#include "oneflow/core/intrusive/flat_msg.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/shape_vec.h"
namespace oneflow {
class Shape;
// clang-format off
FLAT_MSG_BEGIN(FlatShape);
public:
// Methods
static Maybe<FlatShape> New(const Shape& shape);
Maybe<void> Init(const Shape& shape);
Maybe<void> Check(const Shape& shape) const;
Maybe<void> Check(const FlatShape& flat_shape) const;
Maybe<Shape> ToShape() const;
Maybe<void> ToShape(Shape* shape) const;
int64_t At(int i) const { return dim(i); }
int64_t NumAxes() const { return dim_size(); }
// Fields
FLAT_MSG_DEFINE_REPEATED(int64_t, dim, SHAPE_MAX_AXIS_SIZE);
FLAT_MSG_END(FlatShape);
// clang-format on
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment