Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
...@@ -19,15 +19,15 @@ limitations under the License. ...@@ -19,15 +19,15 @@ limitations under the License.
namespace oneflow { namespace oneflow {
namespace { namespace {
std::shared_ptr<ErrorProto>* MutRegistryError() { std::shared_ptr<StackedError>* MutRegistryError() {
static std::shared_ptr<ErrorProto> registry_error; static std::shared_ptr<StackedError> registry_error;
return &registry_error; return &registry_error;
} }
} // namespace } // namespace
Maybe<void> CheckAndClearRegistryFlag() { Maybe<void> CheckAndClearRegistryFlag() {
if (!*MutRegistryError()) { return Maybe<void>::Ok(); } if (!*MutRegistryError()) { return Maybe<void>::Ok(); }
std::shared_ptr<ErrorProto> registry_error_old = *MutRegistryError(); std::shared_ptr<StackedError> registry_error_old = *MutRegistryError();
*MutRegistryError() = nullptr; *MutRegistryError() = nullptr;
return registry_error_old; return registry_error_old;
} }
...@@ -35,7 +35,7 @@ Maybe<void> CheckAndClearRegistryFlag() { ...@@ -35,7 +35,7 @@ Maybe<void> CheckAndClearRegistryFlag() {
void CatchRegistryError(const std::function<Maybe<void>()>& handler) { void CatchRegistryError(const std::function<Maybe<void>()>& handler) {
const auto& maybe_error = TRY(handler()); const auto& maybe_error = TRY(handler());
if (!maybe_error.IsOk()) { if (!maybe_error.IsOk()) {
if (!*MutRegistryError()) { *MutRegistryError() = maybe_error.error(); } if (!*MutRegistryError()) { *MutRegistryError() = maybe_error.stacked_error(); }
} }
} }
......
...@@ -29,7 +29,7 @@ namespace oneflow { ...@@ -29,7 +29,7 @@ namespace oneflow {
} \ } \
return *this; \ return *this; \
} \ } \
Scalar Scalar::operator op(const Scalar& other) { \ Scalar Scalar::operator op(const Scalar& other) const { \
if (IsFloatingPoint() || other.IsFloatingPoint()) { \ if (IsFloatingPoint() || other.IsFloatingPoint()) { \
double val = As<double>() op other.As<double>(); \ double val = As<double>() op other.As<double>(); \
return Scalar(val); \ return Scalar(val); \
......
...@@ -29,28 +29,28 @@ class Scalar { ...@@ -29,28 +29,28 @@ class Scalar {
Scalar() : Scalar(int32_t(0)) {} Scalar() : Scalar(int32_t(0)) {}
template<typename T, typename std::enable_if<std::is_same<T, bool>::value, int>::type = 0> template<typename T, typename std::enable_if<std::is_same<T, bool>::value, int>::type = 0>
Scalar(const T& value) : value_{.b = value}, active_tag_(HAS_B) {} OF_DEVICE_FUNC Scalar(const T& value) : value_{.b = value}, active_tag_(HAS_B) {}
template<typename T, typename std::enable_if< template<typename T, typename std::enable_if<
std::is_integral<T>::value && std::is_signed<T>::value, int>::type = 0> std::is_integral<T>::value && std::is_signed<T>::value, int>::type = 0>
Scalar(const T& value) : value_{.s = value}, active_tag_(HAS_S) {} OF_DEVICE_FUNC Scalar(const T& value) : value_{.s = value}, active_tag_(HAS_S) {}
template<typename T, template<typename T,
typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value
&& !std::is_same<T, bool>::value, && !std::is_same<T, bool>::value,
int>::type = 0> int>::type = 0>
Scalar(const T& value) : value_{.u = value}, active_tag_(HAS_U) {} OF_DEVICE_FUNC Scalar(const T& value) : value_{.u = value}, active_tag_(HAS_U) {}
template<typename T, typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0> template<typename T, typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
Scalar(const T& value) : value_{.d = value}, active_tag_(HAS_D) {} OF_DEVICE_FUNC Scalar(const T& value) : value_{.d = value}, active_tag_(HAS_D) {}
template<typename T, typename std::enable_if<!std::is_same<T, Scalar>::value, int>::type = 0> template<typename T, typename std::enable_if<!std::is_same<T, Scalar>::value, int>::type = 0>
Scalar& operator=(const T& value) { OF_DEVICE_FUNC Scalar& operator=(const T& value) {
*this = Scalar(value); *this = Scalar(value);
return *this; return *this;
} }
Scalar& operator=(const Scalar& other) { OF_DEVICE_FUNC Scalar& operator=(const Scalar& other) {
value_ = other.value_; value_ = other.value_;
active_tag_ = other.active_tag_; active_tag_ = other.active_tag_;
return *this; return *this;
...@@ -78,10 +78,10 @@ class Scalar { ...@@ -78,10 +78,10 @@ class Scalar {
bool IsSigned() const { return active_tag_ == HAS_S || active_tag_ == HAS_D; } bool IsSigned() const { return active_tag_ == HAS_S || active_tag_ == HAS_D; }
bool IsUnsigned() const { return active_tag_ == HAS_U; } bool IsUnsigned() const { return active_tag_ == HAS_U; }
Scalar operator+(const Scalar& other); Scalar operator+(const Scalar& other) const;
Scalar operator-(const Scalar& other); Scalar operator-(const Scalar& other) const;
Scalar operator*(const Scalar& other); Scalar operator*(const Scalar& other) const;
Scalar operator/(const Scalar& other); Scalar operator/(const Scalar& other) const;
Scalar& operator+=(const Scalar& other); Scalar& operator+=(const Scalar& other);
Scalar& operator-=(const Scalar& other); Scalar& operator-=(const Scalar& other);
......
...@@ -220,4 +220,14 @@ Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const { ...@@ -220,4 +220,14 @@ Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
return shape; return shape;
} }
bool Shape::operator==(const Shape& rhs) const {
if (is_initialized_ != rhs.is_initialized_) { return false; }
if (is_initialized_ == false) { return true; }
if (this->NumAxes() != rhs.NumAxes()) { return false; }
FOR_RANGE(int, i, 0, this->NumAxes()) {
if (this->At(i) != rhs.At(i)) { return false; }
}
return true;
}
} // namespace oneflow } // namespace oneflow
...@@ -147,6 +147,8 @@ class Shape final : public DimVector, public MutShapeMixIn<Shape> { ...@@ -147,6 +147,8 @@ class Shape final : public DimVector, public MutShapeMixIn<Shape> {
Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const; Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;
bool operator==(const Shape& rhs) const;
private: private:
// Set default value here because some constructors are inherited from DimVector // Set default value here because some constructors are inherited from DimVector
// TODO(daquexian): remove this field and make it initializied by construction // TODO(daquexian): remove this field and make it initializied by construction
...@@ -170,6 +172,7 @@ namespace std { ...@@ -170,6 +172,7 @@ namespace std {
template<> template<>
struct hash<oneflow::Shape> { struct hash<oneflow::Shape> {
size_t operator()(const oneflow::Shape& shape) const { size_t operator()(const oneflow::Shape& shape) const {
if (!shape.is_initialized()) { return 0; }
size_t ret = shape.NumAxes(); size_t ret = shape.NumAxes();
FOR_RANGE(int, i, 0, shape.NumAxes()) { oneflow::AddHash(&ret, shape.At(i)); } FOR_RANGE(int, i, 0, shape.NumAxes()) { oneflow::AddHash(&ret, shape.At(i)); }
return ret; return ret;
......
...@@ -36,6 +36,7 @@ std::ostream& operator<<(std::ostream& out, ShapeView shape) { ...@@ -36,6 +36,7 @@ std::ostream& operator<<(std::ostream& out, ShapeView shape) {
} }
void MutShapeView::set_shape(ShapeView shape) { void MutShapeView::set_shape(ShapeView shape) {
if (shape.ptr() == mut_ptr() && shape.NumAxes() == NumAxes()) { return; }
CHECK_EQ(NumAxes(), shape.NumAxes()); CHECK_EQ(NumAxes(), shape.NumAxes());
std::copy(shape.ptr(), shape.ptr() + shape.NumAxes(), mut_ptr()); std::copy(shape.ptr(), shape.ptr() + shape.NumAxes(), mut_ptr());
} }
......
...@@ -25,6 +25,7 @@ class small_vector : public llvm::SmallVector<T, N> { ...@@ -25,6 +25,7 @@ class small_vector : public llvm::SmallVector<T, N> {
using Base = llvm::SmallVector<T, N>; using Base = llvm::SmallVector<T, N>;
public: public:
constexpr static size_t kInitialSize = N;
// https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang // https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang
using Base::Base; using Base::Base;
...@@ -36,6 +37,10 @@ class small_vector : public llvm::SmallVector<T, N> { ...@@ -36,6 +37,10 @@ class small_vector : public llvm::SmallVector<T, N> {
CHECK_LT(idx, Base::size()); CHECK_LT(idx, Base::size());
return (*this)[idx]; return (*this)[idx];
} }
typename Base::reference operator[](typename Base::size_type idx) { return this->data()[idx]; }
typename Base::const_reference operator[](typename Base::size_type idx) const {
return this->data()[idx];
}
typename Base::const_iterator cbegin() const { typename Base::const_iterator cbegin() const {
return (typename Base::const_iterator)this->BeginX; return (typename Base::const_iterator)this->BeginX;
} }
......
...@@ -22,7 +22,7 @@ namespace oneflow { ...@@ -22,7 +22,7 @@ namespace oneflow {
Maybe<void> SpinCounter::WaitUntilCntEqualZero() const { Maybe<void> SpinCounter::WaitUntilCntEqualZero() const {
return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&]() -> Maybe<void> { return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&]() -> Maybe<void> {
while (cnt_val_ > 0) {}; while (cnt_val_ > 0) {}
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
} }
......
...@@ -31,6 +31,7 @@ class SpinCounter final { ...@@ -31,6 +31,7 @@ class SpinCounter final {
explicit SpinCounter(int64_t cnt_val) : cnt_val_(cnt_val) {} explicit SpinCounter(int64_t cnt_val) : cnt_val_(cnt_val) {}
int64_t Decrease() { return --cnt_val_; } int64_t Decrease() { return --cnt_val_; }
void Reset(int64_t cnt_val) { cnt_val_ = cnt_val; }
Maybe<void> WaitUntilCntEqualZero() const; Maybe<void> WaitUntilCntEqualZero() const;
private: private:
......
...@@ -34,7 +34,7 @@ class SteadyVector { ...@@ -34,7 +34,7 @@ class SteadyVector {
using size_type = size_t; using size_type = size_t;
// thread safe. // thread safe.
size_t size() const { return size_; } size_t size() const { return size_.load(std::memory_order_acquire); }
// thread safe. // thread safe.
const T& at(size_t index) const { const T& at(size_t index) const {
...@@ -51,12 +51,10 @@ class SteadyVector { ...@@ -51,12 +51,10 @@ class SteadyVector {
return granularity2data_[gran].get()[index - start]; return granularity2data_[gran].get()[index - start];
} }
void push_back(const T& elem) { *MutableOrAdd(size_) = elem; } // `index` should be <= size()
void SetOrAdd(size_t index, T value) {
// `index` shoule be <= size()
T* MutableOrAdd(size_t index) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
size_t size = size_; size_t size = size_.load(std::memory_order_relaxed);
CHECK_LE(index, size) << "index out of range"; CHECK_LE(index, size) << "index out of range";
if (index == size) { if (index == size) {
int granularity = GetGranularity(size); int granularity = GetGranularity(size);
...@@ -64,11 +62,15 @@ class SteadyVector { ...@@ -64,11 +62,15 @@ class SteadyVector {
CHECK_LT(granularity, N); CHECK_LT(granularity, N);
granularity2data_[granularity].reset(new T[1 << granularity]); granularity2data_[granularity].reset(new T[1 << granularity]);
} }
++size_; *Mutable(index) = std::move(value);
size_.fetch_add(1, std::memory_order_release);
} else {
*Mutable(index) = std::move(value);
} }
return Mutable(index);
} }
void push_back(const T& elem) { SetOrAdd(size_, elem); }
private: private:
T* Mutable(size_t index) { T* Mutable(size_t index) {
int gran = 0; int gran = 0;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
#define ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
#include <functional>
#include <array>
#include "oneflow/core/common/preprocessor.h"
#include "glog/logging.h"
namespace oneflow {
// Enumerates the roles a VM stream can play. Each enumerator (except kInvalid)
// corresponds 1:1 to a DerivedT::VisitXxx handler in StreamRoleVisitor below,
// so adding a value here requires adding a matching case/handler there.
enum class StreamRole {
  kInvalid = 0,  // default/unset marker; rejected with LOG(FATAL) by StreamRoleVisitor
  kCompute,
  kHost2Device,
  kDevice2Host,
  kSyncedLaunchedCommNet,
  kAsyncedLaunchedCommNet,
  kBarrier,
  kCriticalSection,
  kLazyJobLauncher,
  kPinnedCompute
};
// CRTP dispatcher: DerivedT provides one static VisitXxx(args...) per
// StreamRole enumerator, and Visit() forwards to the handler matching
// stream_role. All handlers must share a common return type so that the
// `auto` return deduction is consistent across the switch.
template<typename DerivedT>
struct StreamRoleVisitor {
  template<typename... Args>
  static auto Visit(StreamRole stream_role, Args&&... args) {
    switch (stream_role) {
      // LOG(FATAL) aborts the process, so the missing break/return here
      // cannot fall through at runtime.
      case StreamRole::kInvalid: LOG(FATAL) << "invalid stream role";
      case StreamRole::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);
      case StreamRole::kHost2Device: return DerivedT::VisitHost2Device(std::forward<Args>(args)...);
      case StreamRole::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward<Args>(args)...);
      case StreamRole::kSyncedLaunchedCommNet:
        return DerivedT::VisitSyncedLaunchedCommNet(std::forward<Args>(args)...);
      case StreamRole::kAsyncedLaunchedCommNet:
        return DerivedT::VisitAsyncedLaunchedCommNet(std::forward<Args>(args)...);
      case StreamRole::kBarrier: return DerivedT::VisitBarrier(std::forward<Args>(args)...);
      case StreamRole::kCriticalSection:
        return DerivedT::VisitCriticalSection(std::forward<Args>(args)...);
      case StreamRole::kLazyJobLauncher:
        return DerivedT::VisitLazyJobLauncher(std::forward<Args>(args)...);
      case StreamRole::kPinnedCompute:
        return DerivedT::VisitPinnedCompute(std::forward<Args>(args)...);
    }
    // Reached only when stream_role holds a value outside the enum.
    LOG(FATAL) << "invalid stream role";
  }
};
} // namespace oneflow
namespace std {

// Hash support so StreamRole can be used as a key in std::unordered_map /
// std::unordered_set.
template<>
struct hash<oneflow::StreamRole> final {
  size_t operator()(const oneflow::StreamRole& stream_role) const {
    // The enumerator's integral value is itself an adequate hash.
    return static_cast<size_t>(static_cast<int>(stream_role));
  }
};

}  // namespace std
#endif // ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
#define ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
#include <functional>
#include <array>
#include "oneflow/core/common/preprocessor.h"
#include "glog/logging.h"
namespace oneflow {
// Enumerates the concrete stream types. Each enumerator (except kInvalid)
// corresponds 1:1 to a DerivedT::VisitXxx handler in StreamTypeVisitor below,
// so adding a value here requires adding a matching case/handler there.
enum class StreamType {
  kInvalid = 0,  // default/unset marker; rejected with LOG(FATAL) by StreamTypeVisitor
  kCompute,
  kHost2Device,
  kDevice2Host,
  kCcl,
  kBarrier,
  kCriticalSection,
  kLazyJobLauncher,
  kPinnedCompute
};
// CRTP dispatcher: DerivedT provides one static VisitXxx(args...) per
// StreamType enumerator, and Visit() forwards to the handler matching
// stream_type. All handlers must share a common return type so that the
// `auto` return deduction is consistent across the switch.
template<typename DerivedT>
struct StreamTypeVisitor {
  template<typename... Args>
  static auto Visit(StreamType stream_type, Args&&... args) {
    switch (stream_type) {
      // LOG(FATAL) aborts the process, so the missing break/return here
      // cannot fall through at runtime.
      case StreamType::kInvalid: LOG(FATAL) << "invalid stream type";
      case StreamType::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);
      case StreamType::kHost2Device: return DerivedT::VisitHost2Device(std::forward<Args>(args)...);
      case StreamType::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward<Args>(args)...);
      case StreamType::kCcl: return DerivedT::VisitCcl(std::forward<Args>(args)...);
      case StreamType::kBarrier: return DerivedT::VisitBarrier(std::forward<Args>(args)...);
      case StreamType::kCriticalSection:
        return DerivedT::VisitCriticalSection(std::forward<Args>(args)...);
      case StreamType::kLazyJobLauncher:
        return DerivedT::VisitLazyJobLauncher(std::forward<Args>(args)...);
      case StreamType::kPinnedCompute:
        return DerivedT::VisitPinnedCompute(std::forward<Args>(args)...);
    }
    // Reached only when stream_type holds a value outside the enum.
    LOG(FATAL) << "invalid stream type";
  }
};
} // namespace oneflow
namespace std {

// Hash support so StreamType can be used as a key in std::unordered_map /
// std::unordered_set.
template<>
struct hash<oneflow::StreamType> final {
  size_t operator()(const oneflow::StreamType& stream_type) const {
    // The enumerator's integral value is itself an adequate hash.
    return static_cast<size_t>(static_cast<int>(stream_type));
  }
};

}  // namespace std
#endif // ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
...@@ -15,6 +15,7 @@ limitations under the License. ...@@ -15,6 +15,7 @@ limitations under the License.
*/ */
#include "oneflow/core/common/stride.h" #include "oneflow/core/common/stride.h"
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cplusplus_17.h" #include "oneflow/core/common/cplusplus_17.h"
...@@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) { ...@@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) {
std::multiplies<>{}); std::multiplies<>{});
} else if (ndim > 0 && shape.elem_cnt() == 0) { } else if (ndim > 0 && shape.elem_cnt() == 0) {
// 0-size shape // 0-size shape
std::vector<int64_t> tmp_shape(ndim); small_vector<int64_t, kMaxNumDims> tmp_shape(ndim);
for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? shape.At(i) : 1; } for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? shape.At(i) : 1; }
std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1, std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1,
std::multiplies<>{}); std::multiplies<>{});
......
...@@ -22,7 +22,6 @@ limitations under the License. ...@@ -22,7 +22,6 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <glog/logging.h> #include <glog/logging.h>
#include "oneflow/core/common/type_traits.h" #include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h" #include "oneflow/core/common/hash_eq_trait_ptr.h"
namespace oneflow { namespace oneflow {
...@@ -128,12 +127,6 @@ struct SymbolUtil final { ...@@ -128,12 +127,6 @@ struct SymbolUtil final {
static const std::shared_ptr<const T>& GetOrCreatePtr(const T& obj) { static const std::shared_ptr<const T>& GetOrCreatePtr(const T& obj) {
return LocalThreadGetOr<CreateGlobalSymbol>(obj); return LocalThreadGetOr<CreateGlobalSymbol>(obj);
} }
static Maybe<Symbol<T>> GetSymbolByExistedRawPtr(const T* ptr) {
CHECK_GT_OR_RETURN(ThreadLocalSymbolPtrSet()->count(ptr), 0) << "ptr: " << ptr;
Symbol<T> symbol;
symbol.ptr_ = ptr;
return symbol;
}
}; };
template<typename T> template<typename T>
......
...@@ -13,17 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,17 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/framework/tensor_desc.h" #include "oneflow/core/common/tensor_desc.h"
#include "oneflow/core/register/blob_desc.pb.h"
namespace oneflow { namespace oneflow {
namespace user_op { namespace user_op {
TensorDesc& TensorDesc::operator=(const TensorDesc& rhs) { TensorDesc& TensorDesc::operator=(const TensorDesc& rhs) {
*this->mut_shape() = rhs.shape(); this->set_shape(rhs.shape());
*this->mut_stride() = rhs.stride(); this->set_stride(rhs.stride());
*this->mut_data_type() = rhs.data_type(); this->set_data_type(rhs.data_type());
*this->mut_is_dynamic() = rhs.is_dynamic(); this->set_is_dynamic(rhs.is_dynamic());
return *this; return *this;
} }
......
...@@ -13,16 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,16 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ #ifndef ONEFLOW_CORE_COMMON_TENSOR_DESC_H_
#define ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ #define ONEFLOW_CORE_COMMON_TENSOR_DESC_H_
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape.h"
#include "oneflow/core/common/stride.h" #include "oneflow/core/common/stride.h"
#include "oneflow/core/common/data_type.pb.h"
namespace oneflow { namespace oneflow {
class BlobDescProto;
namespace user_op { namespace user_op {
class TensorDesc { class TensorDesc {
...@@ -32,15 +34,14 @@ class TensorDesc { ...@@ -32,15 +34,14 @@ class TensorDesc {
bool operator==(const TensorDesc&) const; bool operator==(const TensorDesc&) const;
virtual const Shape& shape() const = 0; virtual const Shape& shape() const = 0;
virtual Shape* mut_shape() = 0; virtual void set_shape(const Shape& shape) = 0;
virtual const Stride& stride() const = 0; virtual const Stride& stride() const = 0;
virtual Stride* mut_stride() = 0; virtual void set_stride(const Stride& stride) = 0;
virtual DataType data_type() const = 0; virtual DataType data_type() const = 0;
virtual DataType* mut_data_type() = 0; virtual void set_data_type(DataType data_type) = 0;
virtual bool is_dynamic() const = 0; virtual bool is_dynamic() const = 0;
virtual bool* mut_is_dynamic() = 0; virtual void set_is_dynamic(bool is_dynamic) = 0;
virtual void set_is_dynamic(bool val) = 0;
protected: protected:
TensorDesc() = default; TensorDesc() = default;
...@@ -56,15 +57,14 @@ class NaiveTensorDesc final : public TensorDesc { ...@@ -56,15 +57,14 @@ class NaiveTensorDesc final : public TensorDesc {
NaiveTensorDesc& operator=(const BlobDescProto&); NaiveTensorDesc& operator=(const BlobDescProto&);
const Shape& shape() const override { return shape_; } const Shape& shape() const override { return shape_; }
Shape* mut_shape() override { return &shape_; } void set_shape(const Shape& shape) override { shape_ = shape; }
const Stride& stride() const override { return stride_; } const Stride& stride() const override { return stride_; }
Stride* mut_stride() override { return &stride_; } void set_stride(const Stride& stride) override { stride_ = stride; }
DataType data_type() const override { return data_type_; } DataType data_type() const override { return data_type_; }
DataType* mut_data_type() override { return &data_type_; } void set_data_type(DataType data_type) override { data_type_ = data_type; }
bool is_dynamic() const override { return is_dynamic_; } bool is_dynamic() const override { return is_dynamic_; }
bool* mut_is_dynamic() override { return &is_dynamic_; } void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }
void set_is_dynamic(bool val) override { is_dynamic_ = val; }
private: private:
Shape shape_; Shape shape_;
...@@ -77,4 +77,4 @@ class NaiveTensorDesc final : public TensorDesc { ...@@ -77,4 +77,4 @@ class NaiveTensorDesc final : public TensorDesc {
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_DESC_H_ #endif // ONEFLOW_CORE_COMMON_TENSOR_DESC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/tensor_meta.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/common/shape_view.h"
namespace oneflow {
namespace one {
// Uninitialized meta: default (empty) Shape/Stride and kInvalidDataType.
MutTensorMeta::MutTensorMeta()
    : TensorMeta(kInvalidDataType),
      shape_(std::make_shared<const Shape>()),
      stride_(std::make_shared<const Stride>()) {}

// Shape-only ctor: the incoming shape is deep-copied (never aliased) and the
// stride is derived from it via Stride(const Shape&).
MutTensorMeta::MutTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype)
    : TensorMeta(dtype),
      shape_(std::make_shared<const Shape>(*shape)),
      stride_(std::make_shared<const Stride>(*shape)) {}

// Shape + explicit stride; both are deep-copied rather than shared.
MutTensorMeta::MutTensorMeta(const std::shared_ptr<const Shape>& shape,
                             const std::shared_ptr<const Stride>& stride, DataType dtype)
    : TensorMeta(dtype),
      shape_(std::make_shared<const Shape>(*shape)),
      stride_(std::make_shared<const Stride>(*shape)) {}

// Value overloads of the two ctors above.
MutTensorMeta::MutTensorMeta(const Shape& shape, DataType dtype)
    : TensorMeta(dtype),
      shape_(std::make_shared<const Shape>(shape)),
      stride_(std::make_shared<const Stride>(shape)) {}

MutTensorMeta::MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype)
    : TensorMeta(dtype),
      shape_(std::make_shared<const Shape>(shape)),
      stride_(std::make_shared<const Stride>(stride)) {}
// Equality over shape, dtype and stride; is_dynamic_ is deliberately ignored
// (kept in sync with CalcHashValue()).
bool MutTensorMeta::operator==(const MutTensorMeta& other) const {
  if (!(*this->shape_ptr() == *other.shape_ptr())) { return false; }
  if (this->dtype() != other.dtype()) { return false; }
  return this->stride() == other.stride();
}
size_t MutTensorMeta::CalcHashValue() const {
  // It's correct to ignore is_dynamic_ field: the hash covers exactly the
  // fields compared by operator== above.
  return Hash(*shape_ptr(), dtype(), stride());
}
// Uninitialized meta: Symbol-wrapped empty Shape/Stride and an invalid dtype.
ConstTensorMeta::ConstTensorMeta()
    : TensorMeta(kInvalidDataType), shape_(SymbolOf(Shape())), stride_(SymbolOf(Stride())) {}

// Shape-only ctor: the stride defaults to Stride(*shape).
ConstTensorMeta::ConstTensorMeta(Symbol<Shape> shape, DataType dtype)
    : TensorMeta(dtype), shape_(shape), stride_(SymbolOf(Stride(*shape))) {}

ConstTensorMeta::ConstTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype)
    : TensorMeta(dtype), shape_(shape), stride_(stride) {}
// Equality over shape, dtype and stride; is_dynamic_ is deliberately ignored
// (kept in sync with CalcHashValue()).
bool ConstTensorMeta::operator==(const ConstTensorMeta& other) const {
  if (!(*this->shape_ptr() == *other.shape_ptr())) { return false; }
  if (this->dtype() != other.dtype()) { return false; }
  return this->stride() == other.stride();
}
size_t ConstTensorMeta::CalcHashValue() const {
  // It's correct to ignore is_dynamic_ field: the hash covers exactly the
  // fields compared by operator== above.
  return Hash(*shape_ptr(), dtype(), stride());
}
// Uninitialized meta plus a null Device symbol.
LocalTensorMeta::LocalTensorMeta()
    : ConstTensorMeta(SymbolOf(Shape()), SymbolOf(Stride()), DataType::kInvalidDataType),
      device_(Symbol<Device>()) {}

// Shape-only ctor: the stride defaults to Stride(*shape).
LocalTensorMeta::LocalTensorMeta(Symbol<Shape> shape, DataType dtype, Symbol<Device> device)
    : ConstTensorMeta(shape, SymbolOf(Stride(*shape)), dtype), device_(device) {}

LocalTensorMeta::LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,
                                 Symbol<Device> device)
    : ConstTensorMeta(shape, stride, dtype), device_(device) {}
// Equality over shape, dtype, device and stride; is_dynamic_ is deliberately
// ignored (kept in sync with CalcHashValue()).
bool LocalTensorMeta::operator==(const LocalTensorMeta& other) const {
  if (!(*this->shape_ptr() == *other.shape_ptr())) { return false; }
  if (this->dtype() != other.dtype()) { return false; }
  if (!(this->device() == other.device())) { return false; }
  return this->stride() == other.stride();
}
size_t LocalTensorMeta::CalcHashValue() const {
  // It's correct to ignore is_dynamic_ field: the hash covers exactly the
  // fields compared by operator== above (shape, dtype, device, stride).
  return Hash(*shape_ptr(), dtype(), device(), stride());
}
// Uninitialized meta: empty shape/stride, invalid dtype, null Device symbol.
MutLocalTensorMeta::MutLocalTensorMeta()
    : MutTensorMeta(std::make_shared<const Shape>(), std::make_shared<const Stride>(),
                    kInvalidDataType),
      device_(Symbol<Device>()) {}

// Shape-only ctor: the stride defaults to Stride(*shape).
MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,
                                       Symbol<Device> device)
    : MutTensorMeta(shape, std::make_shared<const Stride>(*shape), dtype), device_(device) {}

MutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape,
                                       const std::shared_ptr<const Stride>& stride, DataType dtype,
                                       Symbol<Device> device)
    : MutTensorMeta(shape, stride, dtype), device_(device) {}

// Value overloads of the two ctors above.
MutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, DataType dtype, Symbol<Device> device)
    : MutTensorMeta(shape, Stride(shape), dtype), device_(device) {}

MutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,
                                       Symbol<Device> device)
    : MutTensorMeta(shape, stride, dtype), device_(device) {}
// Equality over shape, dtype, device and stride; is_dynamic_ is deliberately
// ignored (kept in sync with CalcHashValue()). Note devices are compared by
// value (*device()), not by Symbol identity.
bool MutLocalTensorMeta::operator==(const MutLocalTensorMeta& other) const {
  if (!(*this->shape_ptr() == *other.shape_ptr())) { return false; }
  if (this->dtype() != other.dtype()) { return false; }
  if (!(*this->device() == *other.device())) { return false; }
  return this->stride() == other.stride();
}
size_t MutLocalTensorMeta::CalcHashValue() const {
  // It's correct to ignore is_dynamic_ field: the hash covers exactly the
  // fields compared by operator== above (device hashed by value, *device()).
  return Hash(*shape_ptr(), dtype(), *device(), stride());
}
// Equality over shape, dtype, nd_sbp and parallel_desc; is_dynamic_ is
// deliberately ignored (kept in sync with CalcHashValue()).
bool GlobalTensorMeta::operator==(const GlobalTensorMeta& other) const {
  if (!(*this->shape_ptr() == *other.shape_ptr())) { return false; }
  if (this->dtype() != other.dtype()) { return false; }
  if (!(this->nd_sbp() == other.nd_sbp())) { return false; }
  return this->parallel_desc() == other.parallel_desc();
}
size_t GlobalTensorMeta::CalcHashValue() const {
  // Hashes exactly the fields compared by operator== above
  // (is_dynamic_ excluded).
  return Hash(*shape_ptr(), dtype(), nd_sbp(), parallel_desc());
}
// Shape overload: an uninitialized shape is defined to be contiguous;
// otherwise delegates to the ShapeView overload below.
bool IsContiguous(const Shape& shape, const Stride& stride) {
  if (!shape.is_initialized()) { return true; }
  return IsContiguous(ShapeView(shape), stride);
}
// Returns true iff `stride` describes a densely packed (row-major) layout for
// `shape_view`. Rank-0 views and views with at most one element are trivially
// contiguous; any zero-sized axis also makes the tensor contiguous by
// convention, and size-1 axes may carry arbitrary strides.
bool IsContiguous(const ShapeView& shape_view, const Stride& stride) {
  const int64_t ndim = shape_view.NumAxes();
  if (ndim < 1 || shape_view.elem_cnt() <= 1) { return true; }
  int64_t required_stride = 1;
  bool dense_so_far = true;
  // Walk axes from innermost to outermost, tracking the stride a compact
  // layout would require. We keep scanning after a mismatch (instead of
  // returning early) because a later zero-sized axis still forces `true`.
  // https://stackoverflow.com/questions/31681324/identify-contiguous-segments-of-a-non-contiguous-numpy-array
  for (int64_t axis = ndim - 1; axis >= 0; --axis) {
    const int64_t extent = shape_view.At(axis);
    if (extent == 0) { return true; }
    if (dense_so_far && extent != 1) {
      if (stride.at(axis) != required_stride) { dense_so_far = false; }
      required_stride *= extent;
    }
  }
  return dense_so_far;
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_COMMON_TENSOR_META_H_
#define ONEFLOW_COMMON_TENSOR_META_H_
#include <memory>
#include "oneflow/core/common/tensor_desc.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {
class NdSbp;
class Shape;
class Stride;
class Device;
class ParallelDesc;
namespace one {
bool IsContiguous(const Shape& shape, const Stride& stride);
bool IsContiguous(const ShapeView& shape_view, const Stride& stride);
// Abstract base for all tensor metadata (shape / stride / dtype / dynamic
// flag). Extends user_op::TensorDesc with shared-pointer accessors and a
// contiguity query. The set_* overrides here abort on purpose: immutable
// subclasses inherit them as-is, while mutable subclasses (e.g. MutTensorMeta)
// override them with real setters.
class TensorMeta : public user_op::TensorDesc {
 public:
  // NOTE(review): single-argument ctor is implicit — presumably intentional to
  // allow DataType -> TensorMeta conversion in subclasses' init lists; confirm
  // before marking `explicit`.
  TensorMeta(DataType dtype) : data_type_(dtype), is_dynamic_(false) {}
  TensorMeta(const TensorMeta& other) = default;
  TensorMeta(TensorMeta&&) = default;
  virtual ~TensorMeta() = default;

  virtual const std::shared_ptr<const Shape>& shape_ptr() const = 0;
  virtual const std::shared_ptr<const Stride>& stride_ptr() const = 0;
  virtual bool is_contiguous() const = 0;

  DataType dtype() const { return data_type_; }
  DataType data_type() const override { return data_type_; }
  bool is_dynamic() const override { return is_dynamic_; }

  // Reaching any of these defaults is a bug: only mutable subclasses may
  // legally mutate metadata.
  virtual void set_shape(const Shape& shape) override { PRINT_BUG_PROMPT_AND_ABORT(); }
  virtual void set_stride(const Stride& stride) override { PRINT_BUG_PROMPT_AND_ABORT(); }
  virtual void set_data_type(DataType data_type) override { PRINT_BUG_PROMPT_AND_ABORT(); }
  virtual void set_is_dynamic(bool is_dynamic) override { PRINT_BUG_PROMPT_AND_ABORT(); }

 protected:
  DataType data_type_;
  // Excluded from subclasses' operator== / CalcHashValue (per their comments).
  bool is_dynamic_;
};
// Tensor metadata whose shape/stride/dtype can be modified after construction.
// Shape and stride are held as shared_ptr<const T>; the copy ctor and
// operator= always deep-copy, never alias.
class MutTensorMeta : public TensorMeta {
 public:
  // Uninitialized MutTensorMeta (empty shape/stride, invalid dtype; see .cpp).
  MutTensorMeta();
  MutTensorMeta(const MutTensorMeta& other)
      : TensorMeta(other),
        shape_(std::make_shared<const Shape>(*other.shape_)),
        stride_(std::make_shared<const Stride>(*other.stride_)) {}
  MutTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype);
  MutTensorMeta(const std::shared_ptr<const Shape>& shape,
                const std::shared_ptr<const Stride>& stride, DataType dtype);
  MutTensorMeta(const Shape& shape, DataType dtype);
  MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype);
  virtual ~MutTensorMeta() = default;

  const std::shared_ptr<const Shape>& shape_ptr() const override { return shape_; }
  const std::shared_ptr<const Stride>& stride_ptr() const override { return stride_; }
  const Shape& shape() const override { return *shape_; }
  const Stride& stride() const override { return *stride_; }
  bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); }

  // NOTE(review): these setters mutate through const_cast, so the change is
  // visible to every holder of shape_ptr()/stride_ptr(). Since the pointees
  // are created as `const Shape`/`const Stride` (make_shared<const ...>),
  // modifying them after casting away const is undefined behavior per the
  // C++ standard — confirm whether in-place visibility is relied upon before
  // replacing with a fresh make_shared assignment.
  void set_shape(const Shape& shape) override { *const_cast<Shape*>(shape_.get()) = shape; }
  void set_stride(const Stride& stride) override { *const_cast<Stride*>(stride_.get()) = stride; }
  void set_data_type(DataType data_type) override { data_type_ = data_type; }
  void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }

  // Equality/hash ignore is_dynamic_ (see the .cpp definitions).
  bool operator==(const MutTensorMeta& other) const;
  size_t CalcHashValue() const;

  // Deep-copy assignment: allocates fresh Shape/Stride instead of sharing
  // other's pointees.
  MutTensorMeta& operator=(const MutTensorMeta& other) {
    this->data_type_ = other.data_type_;
    this->is_dynamic_ = other.is_dynamic_;
    this->shape_ = std::make_shared<const Shape>(*other.shape_);
    this->stride_ = std::make_shared<const Stride>(*other.stride_);
    return *this;
  }

 protected:
  std::shared_ptr<const Shape> shape_;
  std::shared_ptr<const Stride> stride_;
};
// Immutable tensor metadata backed by interned Symbol<Shape>/Symbol<Stride>,
// making copies and equality cheap. Inherits the aborting setters from
// TensorMeta: this type is never mutated after construction.
class ConstTensorMeta : public TensorMeta {
public:
// uninitialized ConstTensorMeta.
ConstTensorMeta();
ConstTensorMeta(const ConstTensorMeta&) = default;
ConstTensorMeta(Symbol<Shape> shape, DataType dtype);
ConstTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype);
// Convenience overloads: intern the given value via SymbolOf.
ConstTensorMeta(const Shape& shape, DataType dtype) : ConstTensorMeta(SymbolOf(shape), dtype) {}
ConstTensorMeta(const Shape& shape, const Stride& stride, DataType dtype)
: ConstTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype) {}
virtual ~ConstTensorMeta() = default;
// NOTE(review): returning a const reference here assumes
// Symbol<T>::shared_from_symbol() yields a reference to a shared_ptr whose
// storage outlives this call (e.g. owned by the symbol registry) — confirm.
const std::shared_ptr<const Shape>& shape_ptr() const override {
return shape_.shared_from_symbol();
}
const std::shared_ptr<const Stride>& stride_ptr() const override {
return stride_.shared_from_symbol();
}
const Shape& shape() const override { return *shape_; }
const Stride& stride() const override { return *stride_; }
bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); }
bool operator==(const ConstTensorMeta& other) const;
size_t CalcHashValue() const;
// Symbols are shared handles, so assignment is a cheap handle copy
// (unlike MutTensorMeta, which deep-copies).
ConstTensorMeta& operator=(const ConstTensorMeta& other) {
this->data_type_ = other.data_type_;
this->is_dynamic_ = other.is_dynamic_;
this->shape_ = other.shape_;
this->stride_ = other.stride_;
return *this;
}
protected:
Symbol<Shape> shape_;
Symbol<Stride> stride_;
};
// Immutable metadata for a local (single-device) tensor:
// ConstTensorMeta plus the device it lives on.
class LocalTensorMeta : public ConstTensorMeta {
public:
// uninitialized LocalTensorMeta.
LocalTensorMeta();
LocalTensorMeta(const LocalTensorMeta&) = default;
LocalTensorMeta(Symbol<Shape> shape, DataType dtype, Symbol<Device> device);
LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,
Symbol<Device> device);
// Convenience overloads: intern the given value via SymbolOf.
LocalTensorMeta(const Shape& shape, DataType dtype, Symbol<Device> device)
: LocalTensorMeta(SymbolOf(shape), dtype, device) {}
LocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype, Symbol<Device> device)
: LocalTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, device) {}
virtual ~LocalTensorMeta() = default;
const Symbol<Device>& device() const { return device_; }
bool operator==(const LocalTensorMeta& other) const;
// Hash consistent with operator==; used by std::hash specialization below.
size_t CalcHashValue() const;
LocalTensorMeta& operator=(const LocalTensorMeta& other) = default;
private:
Symbol<Device> device_;
};
// Mutable metadata for a local (single-device) tensor:
// MutTensorMeta plus the device it lives on (also mutable via mut_device()).
class MutLocalTensorMeta : public MutTensorMeta {
public:
// uninitialized MutLocalTensorMeta.
MutLocalTensorMeta();
MutLocalTensorMeta(const MutLocalTensorMeta&) = default;
MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,
Symbol<Device> device);
MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape,
const std::shared_ptr<const Stride>& stride, DataType dtype,
Symbol<Device> device);
MutLocalTensorMeta(const Shape& shape, DataType dtype, Symbol<Device> device);
MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,
Symbol<Device> device);
virtual ~MutLocalTensorMeta() = default;
const Symbol<Device>& device() const { return device_; }
// Exposes the device for in-place update.
Symbol<Device>* mut_device() { return &device_; }
bool operator==(const MutLocalTensorMeta& other) const;
size_t CalcHashValue() const;
MutLocalTensorMeta& operator=(const MutLocalTensorMeta& other) = default;
private:
Symbol<Device> device_;
};
// Immutable metadata for a global (multi-device) tensor: ConstTensorMeta
// plus the SBP signature (nd_sbp) and the placement (parallel_desc).
// Note: unlike the other metas, no stride overload is provided here.
class GlobalTensorMeta : public ConstTensorMeta {
public:
GlobalTensorMeta(Symbol<Shape> shape, DataType dtype, Symbol<NdSbp> nd_sbp,
Symbol<ParallelDesc> parallel_desc)
: ConstTensorMeta(shape, dtype), nd_sbp_(nd_sbp), parallel_desc_(parallel_desc) {}
// Convenience overload: interns the shape via SymbolOf.
GlobalTensorMeta(const Shape& shape, DataType dtype, Symbol<NdSbp> nd_sbp,
Symbol<ParallelDesc> parallel_desc)
: GlobalTensorMeta(SymbolOf(shape), dtype, nd_sbp, parallel_desc) {}
GlobalTensorMeta(const GlobalTensorMeta&) = default;
GlobalTensorMeta(GlobalTensorMeta&&) = default;
virtual ~GlobalTensorMeta() = default;
bool operator==(const GlobalTensorMeta& other) const;
Symbol<NdSbp> nd_sbp() const { return nd_sbp_; }
Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }
// Hash consistent with operator==; used by std::hash specialization below.
size_t CalcHashValue() const;
private:
Symbol<NdSbp> nd_sbp_;
Symbol<ParallelDesc> parallel_desc_;
};
} // namespace one
} // namespace oneflow
namespace std {
// Enables LocalTensorMeta as a key in std::unordered_{map,set}; delegates to
// the class's own CalcHashValue(), which is consistent with its operator==.
template<>
struct hash<oneflow::one::LocalTensorMeta> final {
  size_t operator()(const oneflow::one::LocalTensorMeta& meta) const { return meta.CalcHashValue(); }
};
// Enables GlobalTensorMeta as a key in std::unordered_{map,set}; delegates to
// the class's own CalcHashValue(), which is consistent with its operator==.
template<>
struct hash<oneflow::one::GlobalTensorMeta> final {
  size_t operator()(const oneflow::one::GlobalTensorMeta& meta) const { return meta.CalcHashValue(); }
};
} // namespace std
#endif // ONEFLOW_COMMON_TENSOR_META_H_
...@@ -25,6 +25,10 @@ namespace oneflow { ...@@ -25,6 +25,10 @@ namespace oneflow {
template<typename T> template<typename T>
class ThreadLocalGuard { class ThreadLocalGuard {
public: public:
ThreadLocalGuard() {
old_value_ = *MutThreadLocalValue();
*MutThreadLocalValue() = Optional<T>();
}
explicit ThreadLocalGuard(const T& value) { explicit ThreadLocalGuard(const T& value) {
old_value_ = *MutThreadLocalValue(); old_value_ = *MutThreadLocalValue();
*MutThreadLocalValue() = Optional<T>(value); *MutThreadLocalValue() = Optional<T>(value);
......
...@@ -23,21 +23,29 @@ namespace oneflow { ...@@ -23,21 +23,29 @@ namespace oneflow {
namespace details { namespace details {
struct Throw final { struct Throw final {
void operator=(Error&& error) { ThrowError(error.error_proto()); } void operator=(Error&& error) { ThrowError(error.stacked_error()); }
}; };
} // namespace details } // namespace details
} // namespace oneflow } // namespace oneflow
#define THROW(err_type) \ #define THROW(err_type) \
oneflow::details::Throw() = \ ::oneflow::details::Throw() = \
oneflow::Error::err_type().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) ::oneflow::Error::err_type().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
#define CHECK_OR_THROW(expr) \ ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
if (!(expr)) \ return frame; \
oneflow::details::Throw() = \ }(__FUNCTION__))
oneflow::Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
#define CHECK_OR_THROW(expr) \
if (!(expr)) \
::oneflow::details::Throw() = \
::oneflow::Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " << OF_PP_STRINGIZE(expr) << ": " << "Check failed: " << OF_PP_STRINGIZE(expr) << ": "
#define CHECK_EQ_OR_THROW(lhs, rhs) \ #define CHECK_EQ_OR_THROW(lhs, rhs) \
...@@ -66,12 +74,20 @@ struct Throw final { ...@@ -66,12 +74,20 @@ struct Throw final {
#define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr) #define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr)
#define TODO_THEN_THROW() \ #define TODO_THEN_THROW() \
oneflow::details::Throw() = \ ::oneflow::details::Throw() = \
oneflow::Error::TodoError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) ::oneflow::Error::TodoError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
#define UNIMPLEMENTED_THEN_THROW() \ ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
oneflow::details::Throw() = \ return frame; \
oneflow::Error::UnimplementedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) }(__FUNCTION__))
#define UNIMPLEMENTED_THEN_THROW() \
::oneflow::details::Throw() = \
::oneflow::Error::UnimplementedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__))
#endif // ONEFLOW_CORE_COMMON_THROW_H_ #endif // ONEFLOW_CORE_COMMON_THROW_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment