Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
650 additions
and
130 deletions
+650
-130
oneflow/core/common/registry_error.cpp
oneflow/core/common/registry_error.cpp
+4
-4
oneflow/core/common/scalar.cpp
oneflow/core/common/scalar.cpp
+1
-1
oneflow/core/common/scalar.h
oneflow/core/common/scalar.h
+10
-10
oneflow/core/common/shape.cpp
oneflow/core/common/shape.cpp
+10
-0
oneflow/core/common/shape.h
oneflow/core/common/shape.h
+3
-0
oneflow/core/common/shape_view.cpp
oneflow/core/common/shape_view.cpp
+1
-0
oneflow/core/common/small_vector.h
oneflow/core/common/small_vector.h
+5
-0
oneflow/core/common/spin_counter.cpp
oneflow/core/common/spin_counter.cpp
+1
-1
oneflow/core/common/spin_counter.h
oneflow/core/common/spin_counter.h
+1
-0
oneflow/core/common/steady_vector.h
oneflow/core/common/steady_vector.h
+10
-8
oneflow/core/common/stream_role.h
oneflow/core/common/stream_role.h
+0
-77
oneflow/core/common/stream_type.h
oneflow/core/common/stream_type.h
+73
-0
oneflow/core/common/stride.cpp
oneflow/core/common/stride.cpp
+2
-1
oneflow/core/common/symbol.h
oneflow/core/common/symbol.h
+0
-7
oneflow/core/common/tensor_desc.cpp
oneflow/core/common/tensor_desc.cpp
+6
-5
oneflow/core/common/tensor_desc.h
oneflow/core/common/tensor_desc.h
+80
-0
oneflow/core/common/tensor_meta.cpp
oneflow/core/common/tensor_meta.cpp
+168
-0
oneflow/core/common/tensor_meta.h
oneflow/core/common/tensor_meta.h
+239
-0
oneflow/core/common/thread_local_guard.h
oneflow/core/common/thread_local_guard.h
+4
-0
oneflow/core/common/throw.h
oneflow/core/common/throw.h
+32
-16
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/common/registry_error.cpp
View file @
a715222c
...
@@ -19,15 +19,15 @@ limitations under the License.
...
@@ -19,15 +19,15 @@ limitations under the License.
namespace
oneflow
{
namespace
oneflow
{
namespace
{
namespace
{
std
::
shared_ptr
<
Error
Proto
>*
MutRegistryError
()
{
std
::
shared_ptr
<
Stacked
Error
>*
MutRegistryError
()
{
static
std
::
shared_ptr
<
Error
Proto
>
registry_error
;
static
std
::
shared_ptr
<
Stacked
Error
>
registry_error
;
return
&
registry_error
;
return
&
registry_error
;
}
}
}
// namespace
}
// namespace
Maybe
<
void
>
CheckAndClearRegistryFlag
()
{
Maybe
<
void
>
CheckAndClearRegistryFlag
()
{
if
(
!*
MutRegistryError
())
{
return
Maybe
<
void
>::
Ok
();
}
if
(
!*
MutRegistryError
())
{
return
Maybe
<
void
>::
Ok
();
}
std
::
shared_ptr
<
Error
Proto
>
registry_error_old
=
*
MutRegistryError
();
std
::
shared_ptr
<
Stacked
Error
>
registry_error_old
=
*
MutRegistryError
();
*
MutRegistryError
()
=
nullptr
;
*
MutRegistryError
()
=
nullptr
;
return
registry_error_old
;
return
registry_error_old
;
}
}
...
@@ -35,7 +35,7 @@ Maybe<void> CheckAndClearRegistryFlag() {
...
@@ -35,7 +35,7 @@ Maybe<void> CheckAndClearRegistryFlag() {
void
CatchRegistryError
(
const
std
::
function
<
Maybe
<
void
>
()
>&
handler
)
{
void
CatchRegistryError
(
const
std
::
function
<
Maybe
<
void
>
()
>&
handler
)
{
const
auto
&
maybe_error
=
TRY
(
handler
());
const
auto
&
maybe_error
=
TRY
(
handler
());
if
(
!
maybe_error
.
IsOk
())
{
if
(
!
maybe_error
.
IsOk
())
{
if
(
!*
MutRegistryError
())
{
*
MutRegistryError
()
=
maybe_error
.
error
();
}
if
(
!*
MutRegistryError
())
{
*
MutRegistryError
()
=
maybe_error
.
stacked_
error
();
}
}
}
}
}
...
...
oneflow/core/common/scalar.cpp
View file @
a715222c
...
@@ -29,7 +29,7 @@ namespace oneflow {
...
@@ -29,7 +29,7 @@ namespace oneflow {
} \
} \
return *this; \
return *this; \
} \
} \
Scalar Scalar::operator op(const Scalar& other)
{
\
Scalar Scalar::operator op(const Scalar& other)
const {
\
if (IsFloatingPoint() || other.IsFloatingPoint()) { \
if (IsFloatingPoint() || other.IsFloatingPoint()) { \
double val = As<double>() op other.As<double>(); \
double val = As<double>() op other.As<double>(); \
return Scalar(val); \
return Scalar(val); \
...
...
oneflow/core/common/scalar.h
View file @
a715222c
...
@@ -29,28 +29,28 @@ class Scalar {
...
@@ -29,28 +29,28 @@ class Scalar {
Scalar
()
:
Scalar
(
int32_t
(
0
))
{}
Scalar
()
:
Scalar
(
int32_t
(
0
))
{}
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
bool
>
::
value
,
int
>::
type
=
0
>
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
bool
>
::
value
,
int
>::
type
=
0
>
Scalar
(
const
T
&
value
)
:
value_
{.
b
=
value
},
active_tag_
(
HAS_B
)
{}
OF_DEVICE_FUNC
Scalar
(
const
T
&
value
)
:
value_
{.
b
=
value
},
active_tag_
(
HAS_B
)
{}
template
<
typename
T
,
typename
std
::
enable_if
<
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>
::
value
&&
std
::
is_signed
<
T
>::
value
,
int
>::
type
=
0
>
std
::
is_integral
<
T
>
::
value
&&
std
::
is_signed
<
T
>::
value
,
int
>::
type
=
0
>
Scalar
(
const
T
&
value
)
:
value_
{.
s
=
value
},
active_tag_
(
HAS_S
)
{}
OF_DEVICE_FUNC
Scalar
(
const
T
&
value
)
:
value_
{.
s
=
value
},
active_tag_
(
HAS_S
)
{}
template
<
typename
T
,
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>
::
value
&&
std
::
is_unsigned
<
T
>::
value
typename
std
::
enable_if
<
std
::
is_integral
<
T
>
::
value
&&
std
::
is_unsigned
<
T
>::
value
&&
!
std
::
is_same
<
T
,
bool
>::
value
,
&&
!
std
::
is_same
<
T
,
bool
>::
value
,
int
>::
type
=
0
>
int
>::
type
=
0
>
Scalar
(
const
T
&
value
)
:
value_
{.
u
=
value
},
active_tag_
(
HAS_U
)
{}
OF_DEVICE_FUNC
Scalar
(
const
T
&
value
)
:
value_
{.
u
=
value
},
active_tag_
(
HAS_U
)
{}
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>
::
value
,
int
>::
type
=
0
>
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>
::
value
,
int
>::
type
=
0
>
Scalar
(
const
T
&
value
)
:
value_
{.
d
=
value
},
active_tag_
(
HAS_D
)
{}
OF_DEVICE_FUNC
Scalar
(
const
T
&
value
)
:
value_
{.
d
=
value
},
active_tag_
(
HAS_D
)
{}
template
<
typename
T
,
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
Scalar
>
::
value
,
int
>::
type
=
0
>
template
<
typename
T
,
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
Scalar
>
::
value
,
int
>::
type
=
0
>
Scalar
&
operator
=
(
const
T
&
value
)
{
OF_DEVICE_FUNC
Scalar
&
operator
=
(
const
T
&
value
)
{
*
this
=
Scalar
(
value
);
*
this
=
Scalar
(
value
);
return
*
this
;
return
*
this
;
}
}
Scalar
&
operator
=
(
const
Scalar
&
other
)
{
OF_DEVICE_FUNC
Scalar
&
operator
=
(
const
Scalar
&
other
)
{
value_
=
other
.
value_
;
value_
=
other
.
value_
;
active_tag_
=
other
.
active_tag_
;
active_tag_
=
other
.
active_tag_
;
return
*
this
;
return
*
this
;
...
@@ -78,10 +78,10 @@ class Scalar {
...
@@ -78,10 +78,10 @@ class Scalar {
bool
IsSigned
()
const
{
return
active_tag_
==
HAS_S
||
active_tag_
==
HAS_D
;
}
bool
IsSigned
()
const
{
return
active_tag_
==
HAS_S
||
active_tag_
==
HAS_D
;
}
bool
IsUnsigned
()
const
{
return
active_tag_
==
HAS_U
;
}
bool
IsUnsigned
()
const
{
return
active_tag_
==
HAS_U
;
}
Scalar
operator
+
(
const
Scalar
&
other
);
Scalar
operator
+
(
const
Scalar
&
other
)
const
;
Scalar
operator
-
(
const
Scalar
&
other
);
Scalar
operator
-
(
const
Scalar
&
other
)
const
;
Scalar
operator
*
(
const
Scalar
&
other
);
Scalar
operator
*
(
const
Scalar
&
other
)
const
;
Scalar
operator
/
(
const
Scalar
&
other
);
Scalar
operator
/
(
const
Scalar
&
other
)
const
;
Scalar
&
operator
+=
(
const
Scalar
&
other
);
Scalar
&
operator
+=
(
const
Scalar
&
other
);
Scalar
&
operator
-=
(
const
Scalar
&
other
);
Scalar
&
operator
-=
(
const
Scalar
&
other
);
...
...
oneflow/core/common/shape.cpp
View file @
a715222c
...
@@ -220,4 +220,14 @@ Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
...
@@ -220,4 +220,14 @@ Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
return
shape
;
return
shape
;
}
}
bool
Shape
::
operator
==
(
const
Shape
&
rhs
)
const
{
if
(
is_initialized_
!=
rhs
.
is_initialized_
)
{
return
false
;
}
if
(
is_initialized_
==
false
)
{
return
true
;
}
if
(
this
->
NumAxes
()
!=
rhs
.
NumAxes
())
{
return
false
;
}
FOR_RANGE
(
int
,
i
,
0
,
this
->
NumAxes
())
{
if
(
this
->
At
(
i
)
!=
rhs
.
At
(
i
))
{
return
false
;
}
}
return
true
;
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/shape.h
View file @
a715222c
...
@@ -147,6 +147,8 @@ class Shape final : public DimVector, public MutShapeMixIn<Shape> {
...
@@ -147,6 +147,8 @@ class Shape final : public DimVector, public MutShapeMixIn<Shape> {
Maybe
<
Shape
>
Slice
(
int64_t
start_dim
,
int64_t
end_dim
)
const
;
Maybe
<
Shape
>
Slice
(
int64_t
start_dim
,
int64_t
end_dim
)
const
;
bool
operator
==
(
const
Shape
&
rhs
)
const
;
private:
private:
// Set default value here because some constructors are inherited from DimVector
// Set default value here because some constructors are inherited from DimVector
// TODO(daquexian): remove this field and make it initializied by construction
// TODO(daquexian): remove this field and make it initializied by construction
...
@@ -170,6 +172,7 @@ namespace std {
...
@@ -170,6 +172,7 @@ namespace std {
template
<
>
template
<
>
struct
hash
<
oneflow
::
Shape
>
{
struct
hash
<
oneflow
::
Shape
>
{
size_t
operator
()(
const
oneflow
::
Shape
&
shape
)
const
{
size_t
operator
()(
const
oneflow
::
Shape
&
shape
)
const
{
if
(
!
shape
.
is_initialized
())
{
return
0
;
}
size_t
ret
=
shape
.
NumAxes
();
size_t
ret
=
shape
.
NumAxes
();
FOR_RANGE
(
int
,
i
,
0
,
shape
.
NumAxes
())
{
oneflow
::
AddHash
(
&
ret
,
shape
.
At
(
i
));
}
FOR_RANGE
(
int
,
i
,
0
,
shape
.
NumAxes
())
{
oneflow
::
AddHash
(
&
ret
,
shape
.
At
(
i
));
}
return
ret
;
return
ret
;
...
...
oneflow/core/common/shape_view.cpp
View file @
a715222c
...
@@ -36,6 +36,7 @@ std::ostream& operator<<(std::ostream& out, ShapeView shape) {
...
@@ -36,6 +36,7 @@ std::ostream& operator<<(std::ostream& out, ShapeView shape) {
}
}
void
MutShapeView
::
set_shape
(
ShapeView
shape
)
{
void
MutShapeView
::
set_shape
(
ShapeView
shape
)
{
if
(
shape
.
ptr
()
==
mut_ptr
()
&&
shape
.
NumAxes
()
==
NumAxes
())
{
return
;
}
CHECK_EQ
(
NumAxes
(),
shape
.
NumAxes
());
CHECK_EQ
(
NumAxes
(),
shape
.
NumAxes
());
std
::
copy
(
shape
.
ptr
(),
shape
.
ptr
()
+
shape
.
NumAxes
(),
mut_ptr
());
std
::
copy
(
shape
.
ptr
(),
shape
.
ptr
()
+
shape
.
NumAxes
(),
mut_ptr
());
}
}
...
...
oneflow/core/common/small_vector.h
View file @
a715222c
...
@@ -25,6 +25,7 @@ class small_vector : public llvm::SmallVector<T, N> {
...
@@ -25,6 +25,7 @@ class small_vector : public llvm::SmallVector<T, N> {
using
Base
=
llvm
::
SmallVector
<
T
,
N
>
;
using
Base
=
llvm
::
SmallVector
<
T
,
N
>
;
public:
public:
constexpr
static
size_t
kInitialSize
=
N
;
// https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang
// https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang
using
Base
::
Base
;
using
Base
::
Base
;
...
@@ -36,6 +37,10 @@ class small_vector : public llvm::SmallVector<T, N> {
...
@@ -36,6 +37,10 @@ class small_vector : public llvm::SmallVector<T, N> {
CHECK_LT
(
idx
,
Base
::
size
());
CHECK_LT
(
idx
,
Base
::
size
());
return
(
*
this
)[
idx
];
return
(
*
this
)[
idx
];
}
}
typename
Base
::
reference
operator
[](
typename
Base
::
size_type
idx
)
{
return
this
->
data
()[
idx
];
}
typename
Base
::
const_reference
operator
[](
typename
Base
::
size_type
idx
)
const
{
return
this
->
data
()[
idx
];
}
typename
Base
::
const_iterator
cbegin
()
const
{
typename
Base
::
const_iterator
cbegin
()
const
{
return
(
typename
Base
::
const_iterator
)
this
->
BeginX
;
return
(
typename
Base
::
const_iterator
)
this
->
BeginX
;
}
}
...
...
oneflow/core/common/spin_counter.cpp
View file @
a715222c
...
@@ -22,7 +22,7 @@ namespace oneflow {
...
@@ -22,7 +22,7 @@ namespace oneflow {
Maybe
<
void
>
SpinCounter
::
WaitUntilCntEqualZero
()
const
{
Maybe
<
void
>
SpinCounter
::
WaitUntilCntEqualZero
()
const
{
return
Singleton
<
ForeignLockHelper
>::
Get
()
->
WithScopedRelease
([
&
]()
->
Maybe
<
void
>
{
return
Singleton
<
ForeignLockHelper
>::
Get
()
->
WithScopedRelease
([
&
]()
->
Maybe
<
void
>
{
while
(
cnt_val_
>
0
)
{}
;
while
(
cnt_val_
>
0
)
{}
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
});
});
}
}
...
...
oneflow/core/common/spin_counter.h
View file @
a715222c
...
@@ -31,6 +31,7 @@ class SpinCounter final {
...
@@ -31,6 +31,7 @@ class SpinCounter final {
explicit
SpinCounter
(
int64_t
cnt_val
)
:
cnt_val_
(
cnt_val
)
{}
explicit
SpinCounter
(
int64_t
cnt_val
)
:
cnt_val_
(
cnt_val
)
{}
int64_t
Decrease
()
{
return
--
cnt_val_
;
}
int64_t
Decrease
()
{
return
--
cnt_val_
;
}
void
Reset
(
int64_t
cnt_val
)
{
cnt_val_
=
cnt_val
;
}
Maybe
<
void
>
WaitUntilCntEqualZero
()
const
;
Maybe
<
void
>
WaitUntilCntEqualZero
()
const
;
private:
private:
...
...
oneflow/core/common/steady_vector.h
View file @
a715222c
...
@@ -34,7 +34,7 @@ class SteadyVector {
...
@@ -34,7 +34,7 @@ class SteadyVector {
using
size_type
=
size_t
;
using
size_type
=
size_t
;
// thread safe.
// thread safe.
size_t
size
()
const
{
return
size_
;
}
size_t
size
()
const
{
return
size_
.
load
(
std
::
memory_order_acquire
)
;
}
// thread safe.
// thread safe.
const
T
&
at
(
size_t
index
)
const
{
const
T
&
at
(
size_t
index
)
const
{
...
@@ -51,12 +51,10 @@ class SteadyVector {
...
@@ -51,12 +51,10 @@ class SteadyVector {
return
granularity2data_
[
gran
].
get
()[
index
-
start
];
return
granularity2data_
[
gran
].
get
()[
index
-
start
];
}
}
void
push_back
(
const
T
&
elem
)
{
*
MutableOrAdd
(
size_
)
=
elem
;
}
// `index` should be <= size()
void
SetOrAdd
(
size_t
index
,
T
value
)
{
// `index` shoule be <= size()
T
*
MutableOrAdd
(
size_t
index
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
size_t
size
=
size_
;
size_t
size
=
size_
.
load
(
std
::
memory_order_relaxed
)
;
CHECK_LE
(
index
,
size
)
<<
"index out of range"
;
CHECK_LE
(
index
,
size
)
<<
"index out of range"
;
if
(
index
==
size
)
{
if
(
index
==
size
)
{
int
granularity
=
GetGranularity
(
size
);
int
granularity
=
GetGranularity
(
size
);
...
@@ -64,11 +62,15 @@ class SteadyVector {
...
@@ -64,11 +62,15 @@ class SteadyVector {
CHECK_LT
(
granularity
,
N
);
CHECK_LT
(
granularity
,
N
);
granularity2data_
[
granularity
].
reset
(
new
T
[
1
<<
granularity
]);
granularity2data_
[
granularity
].
reset
(
new
T
[
1
<<
granularity
]);
}
}
++
size_
;
*
Mutable
(
index
)
=
std
::
move
(
value
);
size_
.
fetch_add
(
1
,
std
::
memory_order_release
);
}
else
{
*
Mutable
(
index
)
=
std
::
move
(
value
);
}
}
return
Mutable
(
index
);
}
}
void
push_back
(
const
T
&
elem
)
{
SetOrAdd
(
size_
,
elem
);
}
private:
private:
T
*
Mutable
(
size_t
index
)
{
T
*
Mutable
(
size_t
index
)
{
int
gran
=
0
;
int
gran
=
0
;
...
...
oneflow/core/common/stream_role.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
#define ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
#include <functional>
#include <array>
#include "oneflow/core/common/preprocessor.h"
#include "glog/logging.h"
namespace
oneflow
{
enum
class
StreamRole
{
kInvalid
=
0
,
kCompute
,
kHost2Device
,
kDevice2Host
,
kSyncedLaunchedCommNet
,
kAsyncedLaunchedCommNet
,
kBarrier
,
kCriticalSection
,
kLazyJobLauncher
,
kPinnedCompute
};
template
<
typename
DerivedT
>
struct
StreamRoleVisitor
{
template
<
typename
...
Args
>
static
auto
Visit
(
StreamRole
stream_role
,
Args
&&
...
args
)
{
switch
(
stream_role
)
{
case
StreamRole
::
kInvalid
:
LOG
(
FATAL
)
<<
"invalid stream role"
;
case
StreamRole
::
kCompute
:
return
DerivedT
::
VisitCompute
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kHost2Device
:
return
DerivedT
::
VisitHost2Device
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kDevice2Host
:
return
DerivedT
::
VisitDevice2Host
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kSyncedLaunchedCommNet
:
return
DerivedT
::
VisitSyncedLaunchedCommNet
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kAsyncedLaunchedCommNet
:
return
DerivedT
::
VisitAsyncedLaunchedCommNet
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kBarrier
:
return
DerivedT
::
VisitBarrier
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kCriticalSection
:
return
DerivedT
::
VisitCriticalSection
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kLazyJobLauncher
:
return
DerivedT
::
VisitLazyJobLauncher
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamRole
::
kPinnedCompute
:
return
DerivedT
::
VisitPinnedCompute
(
std
::
forward
<
Args
>
(
args
)...);
}
LOG
(
FATAL
)
<<
"invalid stream role"
;
}
};
}
// namespace oneflow
namespace
std
{
template
<
>
struct
hash
<
oneflow
::
StreamRole
>
final
{
size_t
operator
()(
const
oneflow
::
StreamRole
&
stream_role
)
const
{
return
static_cast
<
int
>
(
stream_role
);
}
};
}
// namespace std
#endif // ONEFLOW_CORE_COMMON_STREAM_ROLE_H_
oneflow/core/common/stream_type.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
#define ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
#include <functional>
#include <array>
#include "oneflow/core/common/preprocessor.h"
#include "glog/logging.h"
namespace
oneflow
{
enum
class
StreamType
{
kInvalid
=
0
,
kCompute
,
kHost2Device
,
kDevice2Host
,
kCcl
,
kBarrier
,
kCriticalSection
,
kLazyJobLauncher
,
kPinnedCompute
};
template
<
typename
DerivedT
>
struct
StreamTypeVisitor
{
template
<
typename
...
Args
>
static
auto
Visit
(
StreamType
stream_type
,
Args
&&
...
args
)
{
switch
(
stream_type
)
{
case
StreamType
::
kInvalid
:
LOG
(
FATAL
)
<<
"invalid stream type"
;
case
StreamType
::
kCompute
:
return
DerivedT
::
VisitCompute
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kHost2Device
:
return
DerivedT
::
VisitHost2Device
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kDevice2Host
:
return
DerivedT
::
VisitDevice2Host
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kCcl
:
return
DerivedT
::
VisitCcl
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kBarrier
:
return
DerivedT
::
VisitBarrier
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kCriticalSection
:
return
DerivedT
::
VisitCriticalSection
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kLazyJobLauncher
:
return
DerivedT
::
VisitLazyJobLauncher
(
std
::
forward
<
Args
>
(
args
)...);
case
StreamType
::
kPinnedCompute
:
return
DerivedT
::
VisitPinnedCompute
(
std
::
forward
<
Args
>
(
args
)...);
}
LOG
(
FATAL
)
<<
"invalid stream type"
;
}
};
}
// namespace oneflow
namespace
std
{
template
<
>
struct
hash
<
oneflow
::
StreamType
>
final
{
size_t
operator
()(
const
oneflow
::
StreamType
&
stream_type
)
const
{
return
static_cast
<
int
>
(
stream_type
);
}
};
}
// namespace std
#endif // ONEFLOW_CORE_COMMON_STREAM_TYPE_H_
oneflow/core/common/stride.cpp
View file @
a715222c
...
@@ -15,6 +15,7 @@ limitations under the License.
...
@@ -15,6 +15,7 @@ limitations under the License.
*/
*/
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cplusplus_17.h"
#include "oneflow/core/common/cplusplus_17.h"
...
@@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) {
...
@@ -29,7 +30,7 @@ Stride::Stride(const Shape& shape) {
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
}
else
if
(
ndim
>
0
&&
shape
.
elem_cnt
()
==
0
)
{
}
else
if
(
ndim
>
0
&&
shape
.
elem_cnt
()
==
0
)
{
// 0-size shape
// 0-size shape
s
td
::
vector
<
int64_t
>
tmp_shape
(
ndim
);
s
mall_
vector
<
int64_t
,
kMaxNumDims
>
tmp_shape
(
ndim
);
for
(
int64_t
i
=
0
;
i
<
ndim
;
++
i
)
{
tmp_shape
[
i
]
=
shape
.
At
(
i
)
>
0
?
shape
.
At
(
i
)
:
1
;
}
for
(
int64_t
i
=
0
;
i
<
ndim
;
++
i
)
{
tmp_shape
[
i
]
=
shape
.
At
(
i
)
>
0
?
shape
.
At
(
i
)
:
1
;
}
std
::
exclusive_scan
(
tmp_shape
.
rbegin
(),
tmp_shape
.
rend
(),
rbegin
(),
(
int64_t
)
1
,
std
::
exclusive_scan
(
tmp_shape
.
rbegin
(),
tmp_shape
.
rend
(),
rbegin
(),
(
int64_t
)
1
,
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
...
...
oneflow/core/common/symbol.h
View file @
a715222c
...
@@ -22,7 +22,6 @@ limitations under the License.
...
@@ -22,7 +22,6 @@ limitations under the License.
#include <unordered_set>
#include <unordered_set>
#include <glog/logging.h>
#include <glog/logging.h>
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h"
namespace
oneflow
{
namespace
oneflow
{
...
@@ -128,12 +127,6 @@ struct SymbolUtil final {
...
@@ -128,12 +127,6 @@ struct SymbolUtil final {
static
const
std
::
shared_ptr
<
const
T
>&
GetOrCreatePtr
(
const
T
&
obj
)
{
static
const
std
::
shared_ptr
<
const
T
>&
GetOrCreatePtr
(
const
T
&
obj
)
{
return
LocalThreadGetOr
<
CreateGlobalSymbol
>
(
obj
);
return
LocalThreadGetOr
<
CreateGlobalSymbol
>
(
obj
);
}
}
static
Maybe
<
Symbol
<
T
>>
GetSymbolByExistedRawPtr
(
const
T
*
ptr
)
{
CHECK_GT_OR_RETURN
(
ThreadLocalSymbolPtrSet
()
->
count
(
ptr
),
0
)
<<
"ptr: "
<<
ptr
;
Symbol
<
T
>
symbol
;
symbol
.
ptr_
=
ptr
;
return
symbol
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
oneflow/core/
framework
/tensor_desc.cpp
→
oneflow/core/
common
/tensor_desc.cpp
View file @
a715222c
...
@@ -13,17 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -13,17 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/framework/tensor_desc.h"
#include "oneflow/core/common/tensor_desc.h"
#include "oneflow/core/register/blob_desc.pb.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
user_op
{
namespace
user_op
{
TensorDesc
&
TensorDesc
::
operator
=
(
const
TensorDesc
&
rhs
)
{
TensorDesc
&
TensorDesc
::
operator
=
(
const
TensorDesc
&
rhs
)
{
*
this
->
mu
t_shape
(
)
=
rhs
.
shape
();
this
->
se
t_shape
(
rhs
.
shape
()
)
;
*
this
->
mu
t_stride
(
)
=
rhs
.
stride
();
this
->
se
t_stride
(
rhs
.
stride
()
)
;
*
this
->
mu
t_data_type
(
)
=
rhs
.
data_type
();
this
->
se
t_data_type
(
rhs
.
data_type
()
)
;
*
this
->
mu
t_is_dynamic
(
)
=
rhs
.
is_dynamic
();
this
->
se
t_is_dynamic
(
rhs
.
is_dynamic
()
)
;
return
*
this
;
return
*
this
;
}
}
...
...
oneflow/core/
framework
/tensor_desc.h
→
oneflow/core/
common
/tensor_desc.h
View file @
a715222c
...
@@ -13,16 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -13,16 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#ifndef ONEFLOW_CORE_
FRAMEWORK
_TENSOR_DESC_H_
#ifndef ONEFLOW_CORE_
COMMON
_TENSOR_DESC_H_
#define ONEFLOW_CORE_
FRAMEWORK
_TENSOR_DESC_H_
#define ONEFLOW_CORE_
COMMON
_TENSOR_DESC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/common/data_type.pb.h"
namespace
oneflow
{
namespace
oneflow
{
class
BlobDescProto
;
namespace
user_op
{
namespace
user_op
{
class
TensorDesc
{
class
TensorDesc
{
...
@@ -32,15 +34,14 @@ class TensorDesc {
...
@@ -32,15 +34,14 @@ class TensorDesc {
bool
operator
==
(
const
TensorDesc
&
)
const
;
bool
operator
==
(
const
TensorDesc
&
)
const
;
virtual
const
Shape
&
shape
()
const
=
0
;
virtual
const
Shape
&
shape
()
const
=
0
;
virtual
Shape
*
mut_
shape
(
)
=
0
;
virtual
void
set_shape
(
const
Shape
&
shape
)
=
0
;
virtual
const
Stride
&
stride
()
const
=
0
;
virtual
const
Stride
&
stride
()
const
=
0
;
virtual
Stride
*
mut_
stride
(
)
=
0
;
virtual
void
set_stride
(
const
Stride
&
stride
)
=
0
;
virtual
DataType
data_type
()
const
=
0
;
virtual
DataType
data_type
()
const
=
0
;
virtual
DataType
*
mut_
data_type
(
)
=
0
;
virtual
void
set_data_type
(
DataType
data_type
)
=
0
;
virtual
bool
is_dynamic
()
const
=
0
;
virtual
bool
is_dynamic
()
const
=
0
;
virtual
bool
*
mut_is_dynamic
()
=
0
;
virtual
void
set_is_dynamic
(
bool
is_dynamic
)
=
0
;
virtual
void
set_is_dynamic
(
bool
val
)
=
0
;
protected:
protected:
TensorDesc
()
=
default
;
TensorDesc
()
=
default
;
...
@@ -56,15 +57,14 @@ class NaiveTensorDesc final : public TensorDesc {
...
@@ -56,15 +57,14 @@ class NaiveTensorDesc final : public TensorDesc {
NaiveTensorDesc
&
operator
=
(
const
BlobDescProto
&
);
NaiveTensorDesc
&
operator
=
(
const
BlobDescProto
&
);
const
Shape
&
shape
()
const
override
{
return
shape_
;
}
const
Shape
&
shape
()
const
override
{
return
shape_
;
}
Shape
*
mut_
shape
(
)
override
{
return
&
shape
_
;
}
void
set_shape
(
const
Shape
&
shape
)
override
{
shape_
=
shape
;
}
const
Stride
&
stride
()
const
override
{
return
stride_
;
}
const
Stride
&
stride
()
const
override
{
return
stride_
;
}
Stride
*
mut_
stride
(
)
override
{
return
&
stride
_
;
}
void
set_stride
(
const
Stride
&
stride
)
override
{
stride_
=
stride
;
}
DataType
data_type
()
const
override
{
return
data_type_
;
}
DataType
data_type
()
const
override
{
return
data_type_
;
}
DataType
*
mut_
data_type
(
)
override
{
return
&
data_type
_
;
}
void
set_data_type
(
DataType
data_type
)
override
{
data_type_
=
data_type
;
}
bool
is_dynamic
()
const
override
{
return
is_dynamic_
;
}
bool
is_dynamic
()
const
override
{
return
is_dynamic_
;
}
bool
*
mut_is_dynamic
()
override
{
return
&
is_dynamic_
;
}
void
set_is_dynamic
(
bool
is_dynamic
)
override
{
is_dynamic_
=
is_dynamic
;
}
void
set_is_dynamic
(
bool
val
)
override
{
is_dynamic_
=
val
;
}
private:
private:
Shape
shape_
;
Shape
shape_
;
...
@@ -77,4 +77,4 @@ class NaiveTensorDesc final : public TensorDesc {
...
@@ -77,4 +77,4 @@ class NaiveTensorDesc final : public TensorDesc {
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_
FRAMEWORK
_TENSOR_DESC_H_
#endif // ONEFLOW_CORE_
COMMON
_TENSOR_DESC_H_
oneflow/core/common/tensor_meta.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/tensor_meta.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/common/shape_view.h"
namespace
oneflow
{
namespace
one
{
MutTensorMeta
::
MutTensorMeta
()
:
TensorMeta
(
kInvalidDataType
),
shape_
(
std
::
make_shared
<
const
Shape
>
()),
stride_
(
std
::
make_shared
<
const
Stride
>
())
{}
MutTensorMeta
::
MutTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
std
::
make_shared
<
const
Shape
>
(
*
shape
)),
stride_
(
std
::
make_shared
<
const
Stride
>
(
*
shape
))
{}
MutTensorMeta
::
MutTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
const
std
::
shared_ptr
<
const
Stride
>&
stride
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
std
::
make_shared
<
const
Shape
>
(
*
shape
)),
stride_
(
std
::
make_shared
<
const
Stride
>
(
*
stride
))
{}
MutTensorMeta
::
MutTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
std
::
make_shared
<
const
Shape
>
(
shape
)),
stride_
(
std
::
make_shared
<
const
Stride
>
(
shape
))
{}
MutTensorMeta
::
MutTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
std
::
make_shared
<
const
Shape
>
(
shape
)),
stride_
(
std
::
make_shared
<
const
Stride
>
(
stride
))
{}
bool
MutTensorMeta
::
operator
==
(
const
MutTensorMeta
&
other
)
const
{
// It's correct to ignore is_dynamic_ field.
return
*
this
->
shape_ptr
()
==
*
other
.
shape_ptr
()
&&
this
->
dtype
()
==
other
.
dtype
()
&&
this
->
stride
()
==
other
.
stride
();
}
size_t
MutTensorMeta
::
CalcHashValue
()
const
{
// It's correct to ignore is_dynamic_ field.
return
Hash
(
*
shape_ptr
(),
dtype
(),
stride
());
}
ConstTensorMeta
::
ConstTensorMeta
()
:
TensorMeta
(
kInvalidDataType
),
shape_
(
SymbolOf
(
Shape
())),
stride_
(
SymbolOf
(
Stride
()))
{}
ConstTensorMeta
::
ConstTensorMeta
(
Symbol
<
Shape
>
shape
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
shape
),
stride_
(
SymbolOf
(
Stride
(
*
shape
)))
{}
ConstTensorMeta
::
ConstTensorMeta
(
Symbol
<
Shape
>
shape
,
Symbol
<
Stride
>
stride
,
DataType
dtype
)
:
TensorMeta
(
dtype
),
shape_
(
shape
),
stride_
(
stride
)
{}
bool
ConstTensorMeta
::
operator
==
(
const
ConstTensorMeta
&
other
)
const
{
// It's correct to ignore is_dynamic_ field.
return
*
this
->
shape_ptr
()
==
*
other
.
shape_ptr
()
&&
this
->
dtype
()
==
other
.
dtype
()
&&
this
->
stride
()
==
other
.
stride
();
}
size_t
ConstTensorMeta
::
CalcHashValue
()
const
{
// It's correct to ignore is_dynamic_ field.
return
Hash
(
*
shape_ptr
(),
dtype
(),
stride
());
}
LocalTensorMeta
::
LocalTensorMeta
()
:
ConstTensorMeta
(
SymbolOf
(
Shape
()),
SymbolOf
(
Stride
()),
DataType
::
kInvalidDataType
),
device_
(
Symbol
<
Device
>
())
{}
LocalTensorMeta
::
LocalTensorMeta
(
Symbol
<
Shape
>
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
ConstTensorMeta
(
shape
,
SymbolOf
(
Stride
(
*
shape
)),
dtype
),
device_
(
device
)
{}
LocalTensorMeta
::
LocalTensorMeta
(
Symbol
<
Shape
>
shape
,
Symbol
<
Stride
>
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
ConstTensorMeta
(
shape
,
stride
,
dtype
),
device_
(
device
)
{}
bool
LocalTensorMeta
::
operator
==
(
const
LocalTensorMeta
&
other
)
const
{
// It's correct to ignore is_dynamic_ field.
return
*
this
->
shape_ptr
()
==
*
other
.
shape_ptr
()
&&
this
->
dtype
()
==
other
.
dtype
()
&&
this
->
device
()
==
other
.
device
()
&&
this
->
stride
()
==
other
.
stride
();
}
size_t
LocalTensorMeta
::
CalcHashValue
()
const
{
// It's correct to ignore is_dynamic_ field.
return
Hash
(
*
shape_ptr
(),
dtype
(),
device
(),
stride
());
}
MutLocalTensorMeta
::
MutLocalTensorMeta
()
:
MutTensorMeta
(
std
::
make_shared
<
const
Shape
>
(),
std
::
make_shared
<
const
Stride
>
(),
kInvalidDataType
),
device_
(
Symbol
<
Device
>
())
{}
MutLocalTensorMeta
::
MutLocalTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
MutTensorMeta
(
shape
,
std
::
make_shared
<
const
Stride
>
(
*
shape
),
dtype
),
device_
(
device
)
{}
MutLocalTensorMeta
::
MutLocalTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
const
std
::
shared_ptr
<
const
Stride
>&
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
MutTensorMeta
(
shape
,
stride
,
dtype
),
device_
(
device
)
{}
MutLocalTensorMeta
::
MutLocalTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
MutTensorMeta
(
shape
,
Stride
(
shape
),
dtype
),
device_
(
device
)
{}
MutLocalTensorMeta
::
MutLocalTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
MutTensorMeta
(
shape
,
stride
,
dtype
),
device_
(
device
)
{}
bool
MutLocalTensorMeta
::
operator
==
(
const
MutLocalTensorMeta
&
other
)
const
{
// It's correct to ignore is_dynamic_ field.
return
*
this
->
shape_ptr
()
==
*
other
.
shape_ptr
()
&&
this
->
dtype
()
==
other
.
dtype
()
&&
*
this
->
device
()
==
*
other
.
device
()
&&
this
->
stride
()
==
other
.
stride
();
}
size_t
MutLocalTensorMeta
::
CalcHashValue
()
const
{
// It's correct to ignore is_dynamic_ field.
return
Hash
(
*
shape_ptr
(),
dtype
(),
*
device
(),
stride
());
}
bool
GlobalTensorMeta
::
operator
==
(
const
GlobalTensorMeta
&
other
)
const
{
// It's correct to ignore is_dynamic_ field.
return
*
this
->
shape_ptr
()
==
*
other
.
shape_ptr
()
&&
this
->
dtype
()
==
other
.
dtype
()
&&
this
->
nd_sbp
()
==
other
.
nd_sbp
()
&&
this
->
parallel_desc
()
==
other
.
parallel_desc
();
}
size_t
GlobalTensorMeta
::
CalcHashValue
()
const
{
return
Hash
(
*
shape_ptr
(),
dtype
(),
nd_sbp
(),
parallel_desc
());
}
bool
IsContiguous
(
const
Shape
&
shape
,
const
Stride
&
stride
)
{
if
(
!
shape
.
is_initialized
())
{
return
true
;
}
return
IsContiguous
(
ShapeView
(
shape
),
stride
);
}
bool
IsContiguous
(
const
ShapeView
&
shape_view
,
const
Stride
&
stride
)
{
if
(
shape_view
.
NumAxes
()
<
1
||
shape_view
.
elem_cnt
()
<=
1
)
{
return
true
;
}
int64_t
dim
=
shape_view
.
NumAxes
();
int64_t
expected_stride
=
1
;
bool
contig_if_nonempty
=
true
;
for
(
int64_t
i
=
dim
-
1
;
i
>=
0
;
--
i
)
{
// Contiguous by default when any dim is equal to zero
// https://stackoverflow.com/questions/31681324/identify-contiguous-segments-of-a-non-contiguous-numpy-array
if
(
shape_view
.
At
(
i
)
==
0
)
{
return
true
;
}
if
(
contig_if_nonempty
&&
shape_view
.
At
(
i
)
!=
1
)
{
if
(
stride
.
at
(
i
)
!=
expected_stride
)
{
contig_if_nonempty
=
false
;
}
expected_stride
*=
shape_view
.
At
(
i
);
}
}
return
contig_if_nonempty
;
}
}
// namespace one
}
// namespace oneflow
oneflow/core/common/tensor_meta.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_COMMON_TENSOR_META_H_
#define ONEFLOW_COMMON_TENSOR_META_H_
#include <memory>
#include "oneflow/core/common/tensor_desc.h"
#include "oneflow/core/common/symbol.h"
namespace
oneflow
{
class
NdSbp
;
class
Shape
;
class
Stride
;
class
Device
;
class
ParallelDesc
;
namespace
one
{
bool
IsContiguous
(
const
Shape
&
shape
,
const
Stride
&
stride
);
bool
IsContiguous
(
const
ShapeView
&
shape_view
,
const
Stride
&
stride
);
class
TensorMeta
:
public
user_op
::
TensorDesc
{
public:
TensorMeta
(
DataType
dtype
)
:
data_type_
(
dtype
),
is_dynamic_
(
false
)
{}
TensorMeta
(
const
TensorMeta
&
other
)
=
default
;
TensorMeta
(
TensorMeta
&&
)
=
default
;
virtual
~
TensorMeta
()
=
default
;
virtual
const
std
::
shared_ptr
<
const
Shape
>&
shape_ptr
()
const
=
0
;
virtual
const
std
::
shared_ptr
<
const
Stride
>&
stride_ptr
()
const
=
0
;
virtual
bool
is_contiguous
()
const
=
0
;
DataType
dtype
()
const
{
return
data_type_
;
}
DataType
data_type
()
const
override
{
return
data_type_
;
}
bool
is_dynamic
()
const
override
{
return
is_dynamic_
;
}
virtual
void
set_shape
(
const
Shape
&
shape
)
override
{
PRINT_BUG_PROMPT_AND_ABORT
();
}
virtual
void
set_stride
(
const
Stride
&
stride
)
override
{
PRINT_BUG_PROMPT_AND_ABORT
();
}
virtual
void
set_data_type
(
DataType
data_type
)
override
{
PRINT_BUG_PROMPT_AND_ABORT
();
}
virtual
void
set_is_dynamic
(
bool
is_dynamic
)
override
{
PRINT_BUG_PROMPT_AND_ABORT
();
}
protected:
DataType
data_type_
;
bool
is_dynamic_
;
};
class
MutTensorMeta
:
public
TensorMeta
{
public:
// uninitialized MutTensorMeta.
MutTensorMeta
();
MutTensorMeta
(
const
MutTensorMeta
&
other
)
:
TensorMeta
(
other
),
shape_
(
std
::
make_shared
<
const
Shape
>
(
*
other
.
shape_
)),
stride_
(
std
::
make_shared
<
const
Stride
>
(
*
other
.
stride_
))
{}
MutTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
DataType
dtype
);
MutTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
const
std
::
shared_ptr
<
const
Stride
>&
stride
,
DataType
dtype
);
MutTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
);
MutTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
);
virtual
~
MutTensorMeta
()
=
default
;
const
std
::
shared_ptr
<
const
Shape
>&
shape_ptr
()
const
override
{
return
shape_
;
}
const
std
::
shared_ptr
<
const
Stride
>&
stride_ptr
()
const
override
{
return
stride_
;
}
const
Shape
&
shape
()
const
override
{
return
*
shape_
;
}
const
Stride
&
stride
()
const
override
{
return
*
stride_
;
}
bool
is_contiguous
()
const
override
{
return
IsContiguous
(
*
shape_
,
*
stride_
);
}
void
set_shape
(
const
Shape
&
shape
)
override
{
*
const_cast
<
Shape
*>
(
shape_
.
get
())
=
shape
;
}
void
set_stride
(
const
Stride
&
stride
)
override
{
*
const_cast
<
Stride
*>
(
stride_
.
get
())
=
stride
;
}
void
set_data_type
(
DataType
data_type
)
override
{
data_type_
=
data_type
;
}
void
set_is_dynamic
(
bool
is_dynamic
)
override
{
is_dynamic_
=
is_dynamic
;
}
bool
operator
==
(
const
MutTensorMeta
&
other
)
const
;
size_t
CalcHashValue
()
const
;
MutTensorMeta
&
operator
=
(
const
MutTensorMeta
&
other
)
{
this
->
data_type_
=
other
.
data_type_
;
this
->
is_dynamic_
=
other
.
is_dynamic_
;
this
->
shape_
=
std
::
make_shared
<
const
Shape
>
(
*
other
.
shape_
);
this
->
stride_
=
std
::
make_shared
<
const
Stride
>
(
*
other
.
stride_
);
return
*
this
;
}
protected:
std
::
shared_ptr
<
const
Shape
>
shape_
;
std
::
shared_ptr
<
const
Stride
>
stride_
;
};
class
ConstTensorMeta
:
public
TensorMeta
{
public:
// uninitialized ConstTensorMeta.
ConstTensorMeta
();
ConstTensorMeta
(
const
ConstTensorMeta
&
)
=
default
;
ConstTensorMeta
(
Symbol
<
Shape
>
shape
,
DataType
dtype
);
ConstTensorMeta
(
Symbol
<
Shape
>
shape
,
Symbol
<
Stride
>
stride
,
DataType
dtype
);
ConstTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
)
:
ConstTensorMeta
(
SymbolOf
(
shape
),
dtype
)
{}
ConstTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
)
:
ConstTensorMeta
(
SymbolOf
(
shape
),
SymbolOf
(
stride
),
dtype
)
{}
virtual
~
ConstTensorMeta
()
=
default
;
const
std
::
shared_ptr
<
const
Shape
>&
shape_ptr
()
const
override
{
return
shape_
.
shared_from_symbol
();
}
const
std
::
shared_ptr
<
const
Stride
>&
stride_ptr
()
const
override
{
return
stride_
.
shared_from_symbol
();
}
const
Shape
&
shape
()
const
override
{
return
*
shape_
;
}
const
Stride
&
stride
()
const
override
{
return
*
stride_
;
}
bool
is_contiguous
()
const
override
{
return
IsContiguous
(
*
shape_
,
*
stride_
);
}
bool
operator
==
(
const
ConstTensorMeta
&
other
)
const
;
size_t
CalcHashValue
()
const
;
ConstTensorMeta
&
operator
=
(
const
ConstTensorMeta
&
other
)
{
this
->
data_type_
=
other
.
data_type_
;
this
->
is_dynamic_
=
other
.
is_dynamic_
;
this
->
shape_
=
other
.
shape_
;
this
->
stride_
=
other
.
stride_
;
return
*
this
;
}
protected:
Symbol
<
Shape
>
shape_
;
Symbol
<
Stride
>
stride_
;
};
class
LocalTensorMeta
:
public
ConstTensorMeta
{
public:
// uninitialized LocalTensorMeta.
LocalTensorMeta
();
LocalTensorMeta
(
const
LocalTensorMeta
&
)
=
default
;
LocalTensorMeta
(
Symbol
<
Shape
>
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
);
LocalTensorMeta
(
Symbol
<
Shape
>
shape
,
Symbol
<
Stride
>
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
);
LocalTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
LocalTensorMeta
(
SymbolOf
(
shape
),
dtype
,
device
)
{}
LocalTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
)
:
LocalTensorMeta
(
SymbolOf
(
shape
),
SymbolOf
(
stride
),
dtype
,
device
)
{}
virtual
~
LocalTensorMeta
()
=
default
;
const
Symbol
<
Device
>&
device
()
const
{
return
device_
;
}
bool
operator
==
(
const
LocalTensorMeta
&
other
)
const
;
size_t
CalcHashValue
()
const
;
LocalTensorMeta
&
operator
=
(
const
LocalTensorMeta
&
other
)
=
default
;
private:
Symbol
<
Device
>
device_
;
};
class
MutLocalTensorMeta
:
public
MutTensorMeta
{
public:
// uninitialized MutLocalTensorMeta.
MutLocalTensorMeta
();
MutLocalTensorMeta
(
const
MutLocalTensorMeta
&
)
=
default
;
MutLocalTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
);
MutLocalTensorMeta
(
const
std
::
shared_ptr
<
const
Shape
>&
shape
,
const
std
::
shared_ptr
<
const
Stride
>&
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
);
MutLocalTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
,
Symbol
<
Device
>
device
);
MutLocalTensorMeta
(
const
Shape
&
shape
,
const
Stride
&
stride
,
DataType
dtype
,
Symbol
<
Device
>
device
);
virtual
~
MutLocalTensorMeta
()
=
default
;
const
Symbol
<
Device
>&
device
()
const
{
return
device_
;
}
Symbol
<
Device
>*
mut_device
()
{
return
&
device_
;
}
bool
operator
==
(
const
MutLocalTensorMeta
&
other
)
const
;
size_t
CalcHashValue
()
const
;
MutLocalTensorMeta
&
operator
=
(
const
MutLocalTensorMeta
&
other
)
=
default
;
private:
Symbol
<
Device
>
device_
;
};
class
GlobalTensorMeta
:
public
ConstTensorMeta
{
public:
GlobalTensorMeta
(
Symbol
<
Shape
>
shape
,
DataType
dtype
,
Symbol
<
NdSbp
>
nd_sbp
,
Symbol
<
ParallelDesc
>
parallel_desc
)
:
ConstTensorMeta
(
shape
,
dtype
),
nd_sbp_
(
nd_sbp
),
parallel_desc_
(
parallel_desc
)
{}
GlobalTensorMeta
(
const
Shape
&
shape
,
DataType
dtype
,
Symbol
<
NdSbp
>
nd_sbp
,
Symbol
<
ParallelDesc
>
parallel_desc
)
:
GlobalTensorMeta
(
SymbolOf
(
shape
),
dtype
,
nd_sbp
,
parallel_desc
)
{}
GlobalTensorMeta
(
const
GlobalTensorMeta
&
)
=
default
;
GlobalTensorMeta
(
GlobalTensorMeta
&&
)
=
default
;
virtual
~
GlobalTensorMeta
()
=
default
;
bool
operator
==
(
const
GlobalTensorMeta
&
other
)
const
;
Symbol
<
NdSbp
>
nd_sbp
()
const
{
return
nd_sbp_
;
}
Symbol
<
ParallelDesc
>
parallel_desc
()
const
{
return
parallel_desc_
;
}
size_t
CalcHashValue
()
const
;
private:
Symbol
<
NdSbp
>
nd_sbp_
;
Symbol
<
ParallelDesc
>
parallel_desc_
;
};
}
// namespace one
}
// namespace oneflow
namespace
std
{
template
<
>
struct
hash
<
oneflow
::
one
::
LocalTensorMeta
>
final
{
size_t
operator
()(
const
oneflow
::
one
::
LocalTensorMeta
&
local_tensor_meta
)
const
{
return
local_tensor_meta
.
CalcHashValue
();
}
};
template
<
>
struct
hash
<
oneflow
::
one
::
GlobalTensorMeta
>
final
{
size_t
operator
()(
const
oneflow
::
one
::
GlobalTensorMeta
&
global_tensor_meta
)
const
{
return
global_tensor_meta
.
CalcHashValue
();
}
};
}
// namespace std
#endif // ONEFLOW_COMMON_TENSOR_META_H_
oneflow/core/common/thread_local_guard.h
View file @
a715222c
...
@@ -25,6 +25,10 @@ namespace oneflow {
...
@@ -25,6 +25,10 @@ namespace oneflow {
template
<
typename
T
>
template
<
typename
T
>
class
ThreadLocalGuard
{
class
ThreadLocalGuard
{
public:
public:
ThreadLocalGuard
()
{
old_value_
=
*
MutThreadLocalValue
();
*
MutThreadLocalValue
()
=
Optional
<
T
>
();
}
explicit
ThreadLocalGuard
(
const
T
&
value
)
{
explicit
ThreadLocalGuard
(
const
T
&
value
)
{
old_value_
=
*
MutThreadLocalValue
();
old_value_
=
*
MutThreadLocalValue
();
*
MutThreadLocalValue
()
=
Optional
<
T
>
(
value
);
*
MutThreadLocalValue
()
=
Optional
<
T
>
(
value
);
...
...
oneflow/core/common/throw.h
View file @
a715222c
...
@@ -23,21 +23,29 @@ namespace oneflow {
...
@@ -23,21 +23,29 @@ namespace oneflow {
namespace
details
{
namespace
details
{
struct
Throw
final
{
struct
Throw
final
{
void
operator
=
(
Error
&&
error
)
{
ThrowError
(
error
.
error_proto
());
}
void
operator
=
(
Error
&&
error
)
{
ThrowError
(
error
.
stacked_error
());
}
};
};
}
// namespace details
}
// namespace details
}
// namespace oneflow
}
// namespace oneflow
#define THROW(err_type) \
#define THROW(err_type) \
oneflow::details::Throw() = \
::oneflow::details::Throw() = \
oneflow::Error::err_type().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
::oneflow::Error::err_type().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
#define CHECK_OR_THROW(expr) \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
if (!(expr)) \
return frame; \
oneflow::details::Throw() = \
}(__FUNCTION__))
oneflow::Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
#define CHECK_OR_THROW(expr) \
if (!(expr)) \
::oneflow::details::Throw() = \
::oneflow::Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " << OF_PP_STRINGIZE(expr) << ": "
<< "Check failed: " << OF_PP_STRINGIZE(expr) << ": "
#define CHECK_EQ_OR_THROW(lhs, rhs) \
#define CHECK_EQ_OR_THROW(lhs, rhs) \
...
@@ -66,12 +74,20 @@ struct Throw final {
...
@@ -66,12 +74,20 @@ struct Throw final {
#define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr)
#define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr)
#define TODO_THEN_THROW() \
#define TODO_THEN_THROW() \
oneflow::details::Throw() = \
::oneflow::details::Throw() = \
oneflow::Error::TodoError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
::oneflow::Error::TodoError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
#define UNIMPLEMENTED_THEN_THROW() \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
oneflow::details::Throw() = \
return frame; \
oneflow::Error::UnimplementedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
}(__FUNCTION__))
#define UNIMPLEMENTED_THEN_THROW() \
::oneflow::details::Throw() = \
::oneflow::Error::UnimplementedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = \
::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__))
#endif // ONEFLOW_CORE_COMMON_THROW_H_
#endif // ONEFLOW_CORE_COMMON_THROW_H_
Prev
1
…
13
14
15
16
17
18
19
20
21
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment