Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
21d47d0e
Commit
21d47d0e
authored
Oct 24, 2022
by
yuguo
Browse files
Oneflow 0.8 for DCU
parents
Changes
556
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1786 additions
and
0 deletions
+1786
-0
oneflow/core/common/data_type_seq.h
oneflow/core/common/data_type_seq.h
+72
-0
oneflow/core/common/decorator.h
oneflow/core/common/decorator.h
+207
-0
oneflow/core/common/decorator_test.cpp
oneflow/core/common/decorator_test.cpp
+60
-0
oneflow/core/common/device_type.h
oneflow/core/common/device_type.h
+61
-0
oneflow/core/common/device_type.proto
oneflow/core/common/device_type.proto
+10
-0
oneflow/core/common/dtype_signature.h
oneflow/core/common/dtype_signature.h
+43
-0
oneflow/core/common/dtype_signature.proto
oneflow/core/common/dtype_signature.proto
+8
-0
oneflow/core/common/eigen_util.h
oneflow/core/common/eigen_util.h
+38
-0
oneflow/core/common/either_ptr.h
oneflow/core/common/either_ptr.h
+105
-0
oneflow/core/common/env_var/debug_mode.h
oneflow/core/common/env_var/debug_mode.h
+30
-0
oneflow/core/common/env_var/env_var.h
oneflow/core/common/env_var/env_var.h
+75
-0
oneflow/core/common/env_var/vm.h
oneflow/core/common/env_var/vm.h
+26
-0
oneflow/core/common/error.cpp
oneflow/core/common/error.cpp
+354
-0
oneflow/core/common/error.h
oneflow/core/common/error.h
+144
-0
oneflow/core/common/error.proto
oneflow/core/common/error.proto
+189
-0
oneflow/core/common/error_util.cpp
oneflow/core/common/error_util.cpp
+158
-0
oneflow/core/common/error_util.h
oneflow/core/common/error_util.h
+29
-0
oneflow/core/common/exception.h
oneflow/core/common/exception.h
+57
-0
oneflow/core/common/flat_shape.cpp
oneflow/core/common/flat_shape.cpp
+70
-0
oneflow/core/common/flat_shape.h
oneflow/core/common/flat_shape.h
+50
-0
No files found.
Too many changes to show.
To preserve performance only
556 of 556+
files are displayed.
Plain diff
Email patch
oneflow/core/common/data_type_seq.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
#include "oneflow/core/common/preprocessor.h"
// SEQ
#define BOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)
#define FLOATING_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \
OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
#define SIGNED_INT_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define UNSIGNED_INT_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)
#define INT_DATA_TYPE_SEQ SIGNED_INT_DATA_TYPE_SEQ
#define CHAR_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)
#define ARITHMETIC_DATA_TYPE_SEQ \
FLOATING_DATA_TYPE_SEQ \
INT_DATA_TYPE_SEQ
#define POD_DATA_TYPE_SEQ \
ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ
#define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ
#define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord)
#define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ
#define INDEX_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)
#define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#if defined(WITH_CUDA)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif
#if defined(WITH_ROCM)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif
#define IMAGE_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \
OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)
#define NO_BOXING_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) \
OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer)
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_
oneflow/core/common/decorator.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DECORATOR_H_
#define ONEFLOW_CORE_COMMON_DECORATOR_H_
#include <type_traits>
#include <unordered_map>
#include "tuple_hash.h"
#include "static_check.h"
#include "oneflow/core/common/env_var/env_var.h"
#include "oneflow/core/common/cpp_attribute.h"
namespace
oneflow
{
template
<
template
<
typename
...
>
class
Decorator
>
struct
WithDecorator
final
{
template
<
typename
T
,
typename
=
void
>
struct
Decorate
;
template
<
typename
T
,
typename
...
Args
>
struct
Decorate
<
T
(
*
)(
Args
...)
>
final
{
template
<
T
(
*
func
)(
Args
...)>
static
T
Call
(
Args
...
args
)
{
return
Decorator
<
T
,
Args
...
>::
template
Call
<
func
>(
args
...);
}
};
};
#define DECORATE(fn_ptr, decorator) \
(&WithDecorator<decorator>::Decorate<decltype(fn_ptr)>::Call<fn_ptr>)
template
<
typename
...
Args
>
struct
ThreadLocalCopiable
;
template
<
typename
RetT
>
struct
ThreadLocalCopiable
<
RetT
>
{
template
<
RetT
(
*
func
)()>
static
RetT
Call
()
{
static
thread_local
RetT
value
=
func
();
return
value
;
}
};
template
<
typename
RetT
,
typename
Arg0
>
struct
ThreadLocalCopiable
<
RetT
,
Arg0
>
{
template
<
RetT
(
*
func
)(
Arg0
)>
static
RetT
Call
(
Arg0
arg0
)
{
using
KeyT
=
typename
std
::
decay
<
Arg0
>::
type
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT
,
MappedT
>
map
;
auto
iter
=
map
.
find
(
arg0
);
if
(
iter
==
map
.
end
())
{
iter
=
map
.
emplace
(
arg0
,
func
(
arg0
)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
IsOutArg
<
Arg0
>::
value
,
""
);
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
>::
value
,
""
);
};
template
<
typename
RetT
,
typename
Arg0
,
typename
Arg1
>
struct
ThreadLocalCopiable
<
RetT
,
Arg0
,
Arg1
>
{
template
<
RetT
(
*
func
)(
Arg0
,
Arg1
)>
static
RetT
Call
(
Arg0
arg0
,
Arg1
arg1
)
{
using
KeyT0
=
typename
std
::
decay
<
Arg0
>::
type
;
using
KeyT1
=
typename
std
::
decay
<
Arg1
>::
type
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT0
,
std
::
unordered_map
<
KeyT1
,
MappedT
>>
map
;
auto
*
last_map
=
&
map
[
arg0
];
auto
iter
=
last_map
->
find
(
arg1
);
if
(
iter
==
last_map
->
end
())
{
iter
=
last_map
->
emplace
(
arg1
,
func
(
arg0
,
arg1
)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
,
Arg1
>::
value
,
""
);
};
template
<
typename
RetT
,
typename
Arg0
,
typename
Arg1
,
typename
Arg2
>
struct
ThreadLocalCopiable
<
RetT
,
Arg0
,
Arg1
,
Arg2
>
{
template
<
RetT
(
*
func
)(
Arg0
,
Arg1
,
Arg2
)>
static
RetT
Call
(
Arg0
arg0
,
Arg1
arg1
,
Arg2
arg2
)
{
using
KeyT0
=
typename
std
::
decay
<
Arg0
>::
type
;
using
KeyT1
=
typename
std
::
decay
<
Arg1
>::
type
;
using
KeyT2
=
typename
std
::
decay
<
Arg2
>::
type
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT0
,
std
::
unordered_map
<
KeyT1
,
std
::
unordered_map
<
KeyT2
,
MappedT
>>>
map
;
auto
*
last_map
=
&
map
[
arg0
][
arg1
];
auto
iter
=
last_map
->
find
(
arg2
);
if
(
iter
==
last_map
->
end
())
{
iter
=
last_map
->
emplace
(
arg2
,
func
(
arg0
,
arg1
,
arg2
)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
,
Arg1
,
Arg2
>::
value
,
""
);
};
template
<
typename
RetT
,
typename
Arg0
,
typename
Arg1
,
typename
Arg2
,
typename
Arg3
,
typename
...
Args
>
struct
ThreadLocalCopiable
<
RetT
,
Arg0
,
Arg1
,
Arg2
,
Arg3
,
Args
...
>
{
template
<
RetT
(
*
func
)(
Arg0
,
Arg1
,
Arg2
,
Arg3
,
Args
...)>
static
RetT
Call
(
Arg0
arg0
,
Arg1
arg1
,
Arg2
arg2
,
Arg3
arg3
,
Args
...
args
)
{
using
KeyT0
=
typename
std
::
decay
<
Arg0
>::
type
;
using
KeyT1
=
typename
std
::
decay
<
Arg1
>::
type
;
using
KeyT2
=
typename
std
::
decay
<
Arg2
>::
type
;
using
KeyT3
=
typename
std
::
decay
<
Arg3
>::
type
;
using
KeyT
=
std
::
tuple
<
KeyT0
,
KeyT1
,
KeyT2
,
KeyT3
,
typename
std
::
decay
<
Args
>::
type
...
>
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT
,
MappedT
>
map
;
const
auto
&
key
=
KeyT
(
arg0
,
arg1
,
arg2
,
arg3
,
args
...);
auto
iter
=
map
.
find
(
key
);
if
(
iter
==
map
.
end
())
{
iter
=
map
.
emplace
(
key
,
func
(
arg0
,
arg1
,
arg2
,
arg3
,
args
...)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
,
Arg1
,
Arg2
,
Arg3
,
Args
...
>::
value
,
""
);
};
// for scalar type key.
template
<
typename
RetT
,
typename
...
Args
>
struct
ThreadLocal
:
public
ThreadLocalCopiable
<
RetT
,
Args
...
>
{
private:
static_assert
(
StaticAll
<
IsDecayedScalarType
,
Args
...
>::
value
,
""
);
};
template
<
typename
...
Args
>
struct
ThreadLocalCachedCopiable
;
template
<
typename
RetT
>
struct
ThreadLocalCachedCopiable
<
RetT
>
{
template
<
RetT
(
*
func
)()>
static
RetT
Call
()
{
static
thread_local
RetT
value
=
func
();
return
value
;
}
};
template
<
typename
RetT
,
typename
Arg0
>
struct
ThreadLocalCachedCopiable
<
RetT
,
Arg0
>
{
template
<
RetT
(
*
func
)(
Arg0
)>
static
RetT
Call
(
Arg0
arg0
)
{
using
KeyT
=
typename
std
::
decay
<
Arg0
>::
type
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT
,
MappedT
>
map
;
auto
iter
=
map
.
find
(
arg0
);
if
(
iter
==
map
.
end
())
{
if
(
unlikely
(
map
.
size
()
>=
ThreadLocalEnvInteger
<
ONEFLOW_THRAED_LOCAL_CACHED_SIZE
>
()))
{
map
.
clear
();
}
iter
=
map
.
emplace
(
arg0
,
func
(
arg0
)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
IsOutArg
<
Arg0
>::
value
,
""
);
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
>::
value
,
""
);
};
template
<
typename
RetT
,
typename
Arg0
,
typename
...
Args
>
struct
ThreadLocalCachedCopiable
<
RetT
,
Arg0
,
Args
...
>
{
template
<
RetT
(
*
func
)(
Arg0
,
Args
...)>
static
RetT
Call
(
Arg0
arg0
,
Args
...
args
)
{
using
KeyT0
=
typename
std
::
decay
<
Arg0
>::
type
;
using
KeyT
=
std
::
tuple
<
KeyT0
,
typename
std
::
decay
<
Args
>::
type
...
>
;
using
MappedT
=
typename
std
::
decay
<
RetT
>::
type
;
static
thread_local
std
::
unordered_map
<
KeyT
,
MappedT
>
map
;
const
auto
&
key
=
KeyT
(
arg0
,
args
...);
auto
iter
=
map
.
find
(
key
);
if
(
iter
==
map
.
end
())
{
if
(
unlikely
(
map
.
size
()
>=
ThreadLocalEnvInteger
<
ONEFLOW_THRAED_LOCAL_CACHED_SIZE
>
()))
{
map
.
clear
();
}
iter
=
map
.
emplace
(
key
,
func
(
arg0
,
args
...)).
first
;
}
return
iter
->
second
;
}
private:
static_assert
(
!
StaticAny
<
IsOutArg
,
Arg0
,
Args
...
>::
value
,
""
);
};
// for scalar type key.
template
<
typename
RetT
,
typename
...
Args
>
struct
ThreadLocalCached
:
public
ThreadLocalCachedCopiable
<
RetT
,
Args
...
>
{
private:
static_assert
(
StaticAll
<
IsDecayedScalarType
,
Args
...
>::
value
,
""
);
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DECORATOR_H_
oneflow/core/common/decorator_test.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
namespace
test
{
Maybe
<
int
>
Inc
(
int
x
)
{
return
x
+
1
;
}
Maybe
<
int
>
IncByConstRef
(
const
int
&
x
)
{
return
x
+
1
;
}
TEST
(
ThreadLocal
,
scalar
)
{
auto
*
CachedInc
=
DECORATE
(
&
Inc
,
ThreadLocal
);
int
x
=
CHECK_JUST
(
CachedInc
(
0
));
ASSERT_EQ
(
x
,
1
);
}
TEST
(
ThreadLocal
,
const_ref
)
{
auto
*
CachedIncByConstRef
=
DECORATE
(
&
IncByConstRef
,
ThreadLocal
);
int
x
=
CHECK_JUST
(
CachedIncByConstRef
(
0
));
ASSERT_EQ
(
x
,
1
);
}
namespace
{
struct
Foo
{
static
Maybe
<
Foo
>
New
(
int
x
)
{
return
std
::
shared_ptr
<
Foo
>
(
new
Foo
{
x
});
}
int
x
;
};
}
// namespace
TEST
(
ThreadLocal
,
_class
)
{
auto
*
CachedFooNew
=
DECORATE
(
&
Foo
::
New
,
ThreadLocal
);
const
auto
&
foo
=
CHECK_JUST
(
CachedFooNew
(
10
));
const
auto
&
bar
=
CHECK_JUST
(
CachedFooNew
(
10
));
ASSERT_EQ
(
foo
->
x
,
10
);
ASSERT_TRUE
(
foo
==
bar
);
}
}
// namespace test
}
// namespace oneflow
oneflow/core/common/device_type.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
#define ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
#include "oneflow/core/common/device_type.pb.h"
namespace
std
{
template
<
>
struct
hash
<
oneflow
::
DeviceType
>
final
{
size_t
operator
()(
oneflow
::
DeviceType
device_type
)
const
{
return
static_cast
<
size_t
>
(
device_type
);
}
};
}
// namespace std
namespace
oneflow
{
inline
std
::
string
PrintAvailableDevices
()
{
std
::
string
str
(
"cpu"
);
#if defined(WITH_CUDA) || defined(WITH_ROCM)
str
+=
", cuda"
;
#endif
return
str
;
}
inline
std
::
string
PrintGeneratorAvailableDevices
()
{
std
::
string
str
(
"cpu"
);
#if defined(WITH_CUDA) || defined(WITH_ROCM)
str
+=
", cuda"
;
#endif
str
+=
", auto"
;
// "auto" is a fake device type for random generator.
return
str
;
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#define DEVICE_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA)
#else
#define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)
#endif
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_
oneflow/core/common/device_type.proto
0 → 100644
View file @
21d47d0e
syntax
=
"proto2"
;
package
oneflow
;
enum
DeviceType
{
kInvalidDevice
=
0
;
kCPU
=
1
;
kCUDA
=
2
;
kMockDevice
=
3
;
// pseudo device for test.
kROCm
=
4
;
}
oneflow/core/common/dtype_signature.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
#define ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
#include "oneflow/core/common/dtype_signature.pb.h"
#include "oneflow/core/common/protobuf.h"
namespace
oneflow
{
inline
bool
operator
==
(
const
DTypeSignature
&
lhs
,
const
DTypeSignature
&
rhs
)
{
return
PbMd
().
Equals
(
lhs
,
rhs
);
}
}
// namespace oneflow
namespace
std
{
template
<
>
struct
hash
<
oneflow
::
DTypeSignature
>
final
{
size_t
operator
()(
const
oneflow
::
DTypeSignature
&
dtype_signature
)
{
std
::
string
serialized
;
dtype_signature
.
SerializeToString
(
&
serialized
);
return
std
::
hash
<
std
::
string
>
()(
serialized
);
}
};
}
// namespace std
#endif // ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_
oneflow/core/common/dtype_signature.proto
0 → 100644
View file @
21d47d0e
syntax
=
"proto2"
;
package
oneflow
;
import
"oneflow/core/common/data_type.proto"
;
message
DTypeSignature
{
map
<
string
,
DataType
>
name2dtype
=
1
;
}
oneflow/core/common/eigen_util.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
#define ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
#include "Eigen/Core"
#include "Eigen/Dense"
namespace
oneflow
{
template
<
typename
T
>
using
EigenMatrixMap
=
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
EigenArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
ConstEigenMatrixMap
=
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
ConstEigenArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_
oneflow/core/common/either_ptr.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EITHER_PTR_H_
#define ONEFLOW_CORE_COMMON_EITHER_PTR_H_
#include <glog/logging.h>
#include <memory>
namespace
oneflow
{
template
<
typename
X
,
typename
Y
>
class
EitherPtr
final
{
public:
static_assert
(
!
std
::
is_same
<
X
,
Y
>::
value
,
"X should not be Y"
);
using
XPtr
=
std
::
shared_ptr
<
X
>
;
using
YPtr
=
std
::
shared_ptr
<
Y
>
;
// WARNING: we should assume that the structure of shared_ptr<X> and shared_ptr<Y> is same,
// and obviously at most time the assumption holds
static_assert
(
sizeof
(
XPtr
)
==
sizeof
(
YPtr
),
"unsupported shared_ptr implementation"
);
EitherPtr
()
:
type_
(
UnionType
<
X
>::
value
),
x_ptr_
(
nullptr
)
{}
EitherPtr
(
const
XPtr
&
ptr
)
:
type_
(
UnionType
<
X
>::
value
),
x_ptr_
(
ptr
)
{}
EitherPtr
(
const
YPtr
&
ptr
)
:
type_
(
UnionType
<
Y
>::
value
)
{
new
(
&
x_ptr_
)
YPtr
(
ptr
);
}
EitherPtr
(
XPtr
&&
ptr
)
:
type_
(
UnionType
<
X
>::
value
),
x_ptr_
(
std
::
move
(
ptr
))
{}
EitherPtr
(
YPtr
&&
ptr
)
:
type_
(
UnionType
<
Y
>::
value
)
{
new
(
&
x_ptr_
)
YPtr
(
std
::
move
(
ptr
));
}
EitherPtr
(
const
EitherPtr
&
either_ptr
)
:
type_
(
either_ptr
.
type_
),
x_ptr_
(
either_ptr
.
x_ptr_
)
{}
EitherPtr
(
EitherPtr
&&
either_ptr
)
:
type_
(
either_ptr
.
type_
),
x_ptr_
(
std
::
move
(
either_ptr
.
x_ptr_
))
{}
// the destructor of X or Y will be called properly because it will be stored in the deleter of
// shared_ptr while constructed
~
EitherPtr
()
=
default
;
EitherPtr
&
operator
=
(
const
EitherPtr
&
either_ptr
)
{
x_ptr_
=
either_ptr
.
x_ptr_
;
type_
=
either_ptr
.
type_
;
return
*
this
;
}
EitherPtr
&
operator
=
(
EitherPtr
&&
either_ptr
)
{
x_ptr_
=
std
::
move
(
either_ptr
.
x_ptr_
);
type_
=
either_ptr
.
type_
;
return
*
this
;
}
template
<
typename
T
>
bool
Has
()
const
{
return
type_
==
UnionType
<
T
>::
value
;
}
template
<
typename
T
>
const
std
::
shared_ptr
<
T
>&
Get
()
const
{
return
Get
(
tag
<
T
>
{});
}
private:
template
<
typename
T
,
typename
Enable
=
void
>
struct
UnionType
;
template
<
typename
T
>
struct
UnionType
<
T
,
typename
std
::
enable_if
<
std
::
is_same
<
X
,
T
>::
value
>::
type
>
{
static
constexpr
int8_t
value
=
0
;
};
template
<
typename
T
>
struct
UnionType
<
T
,
typename
std
::
enable_if
<
std
::
is_same
<
Y
,
T
>::
value
>::
type
>
{
static
constexpr
int8_t
value
=
1
;
};
template
<
typename
>
struct
tag
{};
const
XPtr
&
Get
(
tag
<
X
>
)
const
{
CHECK
(
Has
<
X
>
());
return
x_ptr_
;
}
const
YPtr
&
Get
(
tag
<
Y
>
)
const
{
CHECK
(
Has
<
Y
>
());
const
auto
*
__attribute__
((
__may_alias__
))
ptr
=
reinterpret_cast
<
const
YPtr
*>
(
&
x_ptr_
);
return
*
ptr
;
}
int8_t
type_
;
std
::
shared_ptr
<
X
>
x_ptr_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EITHER_PTR_H_
oneflow/core/common/env_var/debug_mode.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace
oneflow
{
DEFINE_ENV_BOOL
(
ONEFLOW_DEBUG_MODE
,
false
);
DEFINE_ENV_BOOL
(
ONEFLOW_DEBUG
,
false
);
inline
bool
IsInDebugMode
()
{
return
EnvBool
<
ONEFLOW_DEBUG_MODE
>
()
||
EnvBool
<
ONEFLOW_DEBUG
>
();
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
oneflow/core/common/env_var/env_var.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
#include "oneflow/core/common/util.h"
namespace
oneflow
{
template
<
typename
env_var
>
bool
EnvBool
();
#define DEFINE_ENV_BOOL(env_var, default_value) \
struct env_var {}; \
template<> \
inline bool EnvBool<env_var>() { \
return ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
}
template
<
typename
env_var
>
int64_t
EnvInteger
();
#define DEFINE_ENV_INTEGER(env_var, default_value) \
struct env_var {}; \
template<> \
inline int64_t EnvInteger<env_var>() { \
return ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
}
DEFINE_ENV_INTEGER
(
ONEFLOW_TIMEOUT_SECONDS
,
7200
);
DEFINE_ENV_INTEGER
(
ONEFLOW_CHECK_TIMEOUT_SLEEP_SECONDS
,
EnvInteger
<
ONEFLOW_TIMEOUT_SECONDS
>
());
DEFINE_ENV_INTEGER
(
ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT
,
100
);
DEFINE_ENV_INTEGER
(
ONEFLOW_DELETE_OUTDATED_SHM_NAMES_INTERVAL
,
1000
);
template
<
typename
env_var
>
bool
ThreadLocalEnvBool
();
#define DEFINE_THREAD_LOCAL_ENV_BOOL(env_var, default_value) \
struct env_var {}; \
template<> \
inline bool ThreadLocalEnvBool<env_var>() { \
thread_local bool value = ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
return value; \
}
template
<
typename
env_var
>
int64_t
ThreadLocalEnvInteger
();
#define DEFINE_THREAD_LOCAL_ENV_INTEGER(env_var, default_value) \
struct env_var {}; \
template<> \
inline int64_t ThreadLocalEnvInteger<env_var>() { \
thread_local int64_t value = ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
return value; \
}
DEFINE_THREAD_LOCAL_ENV_INTEGER
(
ONEFLOW_THRAED_LOCAL_CACHED_SIZE
,
128
*
1024
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
oneflow/core/common/env_var/vm.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace
oneflow
{
DEFINE_THREAD_LOCAL_ENV_BOOL
(
ONEFLOW_VM_WORKLOAD_ON_SCHEDULER_THREAD
,
false
);
}
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
oneflow/core/common/error.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <stdexcept>
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/exception.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/env_var/debug_mode.h"
namespace
oneflow
{
namespace
{
void
LogError
(
const
Error
&
error
)
{
// gdb break point
LOG
(
ERROR
)
<<
error
->
msg
();
}
std
::
shared_ptr
<
ErrorProto
>*
MutThreadLocalError
()
{
thread_local
std
::
shared_ptr
<
ErrorProto
>
error
;
return
&
error
;
}
}
// namespace
Error
&&
Error
::
AddStackFrame
(
const
std
::
string
&
file
,
const
int64_t
&
line
,
const
std
::
string
&
function
)
{
auto
*
stack_frame
=
error_proto_
->
add_stack_frame
();
stack_frame
->
set_file
(
file
);
stack_frame
->
set_line
(
line
);
stack_frame
->
set_function
(
function
);
return
std
::
move
(
*
this
);
}
void
Error
::
Merge
(
const
Error
&
other
)
{
std
::
string
error_summary
{
error_proto_
->
error_summary
()};
std
::
string
msg
{
error_proto_
->
msg
()};
error_proto_
->
MergeFrom
(
*
other
.
error_proto_
);
// MergeFrom will overwrite singular field, so restore it.
if
(
!
error_summary
.
empty
())
{
error_proto_
->
set_error_summary
(
error_summary
+
" "
+
error_proto_
->
error_summary
());
}
if
(
!
msg
.
empty
())
{
error_proto_
->
set_msg
(
msg
+
" "
+
error_proto_
->
msg
());
}
}
Error
::
operator
std
::
string
()
const
{
return
error_proto_
->
DebugString
();
}
Error
Error
::
Ok
()
{
return
std
::
make_shared
<
ErrorProto
>
();
}
Error
Error
::
ProtoParseFailedError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_proto_parse_failed_error
();
return
error
;
}
Error
Error
::
JobSetEmptyError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_set_empty_error
();
return
error
;
}
Error
Error
::
DeviceTagNotFoundError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_device_tag_not_found_error
();
return
error
;
}
Error
Error
::
InvalidValueError
(
const
std
::
string
&
error_summary
)
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
set_error_summary
(
error_summary
);
error
->
mutable_invalid_value_error
();
return
error
;
}
Error
Error
::
IndexError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_index_error
();
return
error
;
}
Error
Error
::
TypeError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_type_error
();
return
error
;
}
Error
Error
::
TimeoutError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_timeout_error
();
return
error
;
}
Error
Error
::
JobNameExistError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_name_exist_error
();
return
error
;
}
Error
Error
::
JobNameEmptyError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_name_empty_error
();
return
error
;
}
Error
Error
::
JobNameNotEqualError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_name_not_equal_error
();
return
error
;
}
Error
Error
::
NoJobBuildAndInferCtxError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_no_job_build_and_infer_ctx_error
();
return
error
;
}
Error
Error
::
JobConfFrozenError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_conf_frozen_error
();
return
error
;
}
Error
Error
::
JobConfNotSetError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_conf_not_set_error
();
return
error
;
}
Error
Error
::
JobConfRepeatedSetError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_conf_repeated_set_error
();
return
error
;
}
Error
Error
::
JobTypeNotSetError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_job_type_not_set_error
();
return
error
;
}
Error
Error
::
LogicalBlobNameNotExistError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_logical_blob_name_not_exist_error
();
return
error
;
}
Error
Error
::
LogicalBlobNameExistError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_logical_blob_name_exist_error
();
return
error
;
}
Error
Error
::
LogicalBlobNameInvalidError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_logical_blob_name_invalid_error
();
return
error
;
}
Error
Error
::
OpNameExistError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_op_name_exist_error
();
return
error
;
}
Error
Error
::
OpConfDeviceTagNoSetError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_op_conf_device_tag_no_set_error
();
return
error
;
}
Error
Error
::
PlacementError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_placement_error
();
return
error
;
}
Error
Error
::
BlobSplitAxisInferError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_blob_split_axis_infer_error
();
return
error
;
}
Error
Error
::
UnknownJobBuildAndInferError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_unknown_job_build_and_infer_error
();
return
error
;
}
Error
Error
::
CheckFailedError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_check_failed_error
();
return
error
;
}
Error
Error
::
ValueNotFoundError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_value_not_found_error
();
return
error
;
}
Error
Error
::
TodoError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_todo_error
();
return
error
;
}
Error
Error
::
UnimplementedError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_unimplemented_error
();
return
error
;
}
Error
Error
::
RuntimeError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_runtime_error
();
return
error
;
}
Error
Error
::
OutOfMemoryError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_out_of_memory_error
();
return
error
;
}
Error
Error
::
BoxingNotSupportedError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_boxing_not_supported_error
();
return
error
;
}
Error
Error
::
OpKernelNotFoundError
(
const
std
::
string
&
error_summary
,
const
std
::
vector
<
std
::
string
>&
error_msgs
)
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
set_error_summary
(
error_summary
);
auto
*
op_kernel_not_found_error
=
error
->
mutable_op_kernel_not_found_error
();
for
(
const
auto
&
msg
:
error_msgs
)
{
op_kernel_not_found_error
->
add_op_kernels_not_found_debug_str
(
msg
);
}
return
error
;
}
Error
Error
::
MultipleOpKernelsMatchedError
(
const
std
::
string
&
error_summary
,
const
std
::
vector
<
std
::
string
>&
error_msgs
)
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
set_error_summary
(
error_summary
);
auto
*
multiple_op_kernels_matched_error
=
error
->
mutable_multiple_op_kernels_matched_error
();
for
(
const
auto
&
msg
:
error_msgs
)
{
multiple_op_kernels_matched_error
->
add_matched_op_kernels_debug_str
(
msg
);
}
return
error
;
}
Error
Error
::
MemoryZoneOutOfMemoryError
(
int64_t
machine_id
,
int64_t
mem_zone_id
,
uint64_t
calc
,
uint64_t
available
,
const
std
::
string
&
device_tag
)
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
auto
*
memory_zone_out_of_memory_error
=
error
->
mutable_memory_zone_out_of_memory_error
();
memory_zone_out_of_memory_error
->
add_machine_id
(
std
::
to_string
(
machine_id
));
memory_zone_out_of_memory_error
->
add_mem_zone_id
(
std
::
to_string
(
mem_zone_id
));
memory_zone_out_of_memory_error
->
add_device_tag
(
device_tag
);
memory_zone_out_of_memory_error
->
add_available
(
std
::
to_string
(
available
)
+
" bytes"
);
memory_zone_out_of_memory_error
->
add_required
(
std
::
to_string
(
calc
)
+
" bytes"
);
return
error
;
}
Error
Error
::
LossBlobNotFoundError
(
const
std
::
string
&
error_summary
)
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_loss_blob_not_found_error
();
error
->
set_error_summary
(
error_summary
);
return
error
;
}
Error
Error
::
RwMutexedObjectNotFoundError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_rw_mutexed_object_not_found_error
();
return
error
;
}
Error
Error
::
GradientFunctionNotFoundError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_gradient_function_not_found_error
();
return
error
;
}
Error
Error
::
SymbolIdUninitializedError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_symbol_id_uninitialized_error
();
return
error
;
}
Error
Error
::
CompileOptionWrongError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
error
->
mutable_compile_option_wrong_error
();
return
error
;
}
Error
Error
::
InputDeviceNotMatchError
()
{
auto
error
=
std
::
make_shared
<
ErrorProto
>
();
auto
*
input_device_not_match_error
=
error
->
mutable_input_device_not_match_error
();
input_device_not_match_error
->
add_info
(
std
::
string
(
"Input tensors are at different devices, please try to use tensor.to or "
"module.to to correct it."
));
return
error
;
}
std
::
string
GetStackedErrorString
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
const
auto
&
maybe_error
=
TRY
(
FormatErrorStr
(
error
));
const
auto
&
error_str
=
maybe_error
.
GetDataAndErrorProto
(
error
->
DebugString
());
CHECK_NE
(
error
->
error_type_case
(),
ErrorProto
::
ERROR_TYPE_NOT_SET
);
return
error_str
.
first
;
}
std
::
string
GetErrorString
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
if
(
IsInDebugMode
())
{
return
GetStackedErrorString
(
error
);
}
else
{
if
(
error
->
msg
().
empty
()
&&
error
->
stack_frame
().
size
()
>
0
)
{
return
error
->
stack_frame
(
0
).
error_msg
();
}
else
{
return
error
->
msg
();
}
}
}
void
ThrowError
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
*
MutThreadLocalError
()
=
error
;
if
(
error
->
has_runtime_error
())
{
throw
RuntimeException
(
GetErrorString
(
error
));
}
if
(
error
->
has_type_error
())
{
throw
TypeException
(
GetErrorString
(
error
));
}
if
(
error
->
has_index_error
())
{
throw
IndexException
(
GetErrorString
(
error
));
}
if
(
error
->
has_unimplemented_error
())
{
throw
NotImplementedException
(
GetErrorString
(
error
));
}
throw
Exception
(
GetStackedErrorString
(
error
));
}
const
std
::
shared_ptr
<
ErrorProto
>&
ThreadLocalError
()
{
return
*
MutThreadLocalError
();
}
const
char
*
kOfBugIssueUploadPrompt
=
"This is a oneflow bug, please submit issues in "
"'https://github.com/Oneflow-Inc/oneflow/issues' include the log information of the error, the "
"minimum reproduction code, and the system information."
;
}
// namespace oneflow
oneflow/core/common/error.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ERROR_H_
#define ONEFLOW_CORE_COMMON_ERROR_H_
#include <sstream>
#include <vector>
#include "oneflow/core/common/error.pb.h"
namespace
oneflow
{
class
Error
final
{
public:
Error
(
const
std
::
shared_ptr
<
ErrorProto
>&
error_proto
)
:
error_proto_
(
error_proto
)
{}
Error
(
const
Error
&
)
=
default
;
~
Error
()
=
default
;
std
::
shared_ptr
<
ErrorProto
>
error_proto
()
const
{
return
error_proto_
;
}
const
ErrorProto
*
operator
->
()
const
{
return
error_proto_
.
get
();
}
ErrorProto
*
operator
->
()
{
return
error_proto_
.
get
();
}
operator
std
::
string
()
const
;
void
Assign
(
const
Error
&
other
)
{
error_proto_
=
other
.
error_proto_
;
}
void
Merge
(
const
Error
&
other
);
// r-value reference is used to supporting expressions like `Error().AddStackFrame("foo.cpp",
// ,"line", "Bar") << "invalid value"` because operator<<() need r-value reference
Error
&&
AddStackFrame
(
const
std
::
string
&
file
,
const
int64_t
&
line
,
const
std
::
string
&
function
);
static
Error
Ok
();
static
Error
ProtoParseFailedError
();
static
Error
JobSetEmptyError
();
static
Error
DeviceTagNotFoundError
();
static
Error
InvalidValueError
(
const
std
::
string
&
error_summary
);
static
Error
IndexError
();
static
Error
TypeError
();
static
Error
TimeoutError
();
static
Error
JobNameExistError
();
static
Error
JobNameEmptyError
();
static
Error
JobNameNotEqualError
();
static
Error
NoJobBuildAndInferCtxError
();
static
Error
JobConfFrozenError
();
static
Error
JobConfNotSetError
();
static
Error
JobConfRepeatedSetError
();
static
Error
JobTypeNotSetError
();
static
Error
LogicalBlobNameNotExistError
();
static
Error
LogicalBlobNameExistError
();
static
Error
LogicalBlobNameInvalidError
();
static
Error
OpNameExistError
();
static
Error
OpConfDeviceTagNoSetError
();
static
Error
PlacementError
();
static
Error
BlobSplitAxisInferError
();
static
Error
UnknownJobBuildAndInferError
();
static
Error
CheckFailedError
();
static
Error
ValueNotFoundError
();
static
Error
TodoError
();
static
Error
UnimplementedError
();
static
Error
RuntimeError
();
static
Error
OutOfMemoryError
();
static
Error
BoxingNotSupportedError
();
static
Error
MemoryZoneOutOfMemoryError
(
int64_t
machine_id
,
int64_t
mem_zone_id
,
uint64_t
calc
,
uint64_t
available
,
const
std
::
string
&
device_type
);
static
Error
OpKernelNotFoundError
(
const
std
::
string
&
error_summary
,
const
std
::
vector
<
std
::
string
>&
error_msgs
);
static
Error
MultipleOpKernelsMatchedError
(
const
std
::
string
&
error_summary
,
const
std
::
vector
<
std
::
string
>&
error_msgs
);
static
Error
LossBlobNotFoundError
(
const
std
::
string
&
error_summary
);
static
Error
RwMutexedObjectNotFoundError
();
// gradient
static
Error
GradientFunctionNotFoundError
();
// symbol
static
Error
SymbolIdUninitializedError
();
static
Error
CompileOptionWrongError
();
static
Error
InputDeviceNotMatchError
();
private:
std
::
shared_ptr
<
ErrorProto
>
error_proto_
;
};
void
ThrowError
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
);
const
std
::
shared_ptr
<
ErrorProto
>&
ThreadLocalError
();
template
<
typename
T
>
Error
&
operator
<<
(
Error
&
error
,
const
T
&
x
)
{
std
::
ostringstream
ss
;
ss
<<
x
;
if
(
error
->
stack_frame
().
empty
())
{
error
->
set_msg
(
error
->
msg
()
+
ss
.
str
());
}
else
{
auto
*
stack_frame_top
=
error
->
mutable_stack_frame
(
error
->
stack_frame_size
()
-
1
);
stack_frame_top
->
set_error_msg
(
stack_frame_top
->
error_msg
()
+
ss
.
str
());
}
return
error
;
}
// r-value reference is used to supporting expressions like `Error() << "invalid value"`
template
<
typename
T
>
Error
&&
operator
<<
(
Error
&&
error
,
const
T
&
x
)
{
error
<<
x
;
return
std
::
move
(
error
);
}
template
<
>
inline
Error
&&
operator
<<
(
Error
&&
error
,
const
std
::
stringstream
&
x
)
{
error
<<
x
.
str
();
return
std
::
move
(
error
);
}
template
<
>
inline
Error
&&
operator
<<
(
Error
&&
error
,
const
std
::
ostream
&
x
)
{
error
<<
x
.
rdbuf
();
return
std
::
move
(
error
);
}
template
<
>
inline
Error
&&
operator
<<
(
Error
&&
error
,
const
Error
&
other
)
{
error
.
Merge
(
other
);
return
std
::
move
(
error
);
}
extern
const
char
*
kOfBugIssueUploadPrompt
;
}
// namespace oneflow
#define PRINT_BUG_PROMPT_AND_ABORT() LOG(FATAL) << kOfBugIssueUploadPrompt
#endif // ONEFLOW_CORE_COMMON_ERROR_H_
oneflow/core/common/error.proto
0 → 100644
View file @
21d47d0e
syntax
=
"proto2"
;
package
oneflow
;
message
FieldValue
{
required
string
field
=
1
;
required
string
value
=
2
;
}
enum
OpcodeType
{
kInvalidCompareType
=
0
;
kEq
=
1
;
kNe
=
2
;
kGt
=
3
;
kGe
=
4
;
kLt
=
5
;
kLe
=
6
;
}
message
OneFieldAssertError
{
required
OpcodeType
compare_type
=
1
;
required
FieldValue
left
=
2
;
required
string
right_value
=
3
;
}
message
TwoFieldAssertError
{
required
OpcodeType
compare_type
=
1
;
required
FieldValue
left
=
2
;
required
FieldValue
right
=
3
;
}
message
ConfigAssertFailedError
{
oneof
oprand_type
{
OneFieldAssertError
one_field_assert_error
=
1
;
TwoFieldAssertError
two_field_assert_error
=
2
;
}
}
message
ConfigResourceUnavailableError
{
required
FieldValue
field_value
=
1
;
}
message
JobSetEmptyError
{
}
message
DeviceTagNotFoundError
{
}
message
JobNameExistError
{
}
message
JobNameEmptyError
{
}
message
JobNameNotEqualError
{
}
message
NoJobBuildAndInferCtxError
{
}
message
JobConfFrozenError
{
}
message
JobConfNotSetError
{
}
message
JobConfRepeatedSetError
{
}
message
JobTypeNotSetError
{
}
message
LogicalBlobNameNotExistError
{
}
message
LogicalBlobNameExistError
{
}
message
LogicalBlobNameInvalidError
{
}
message
OpNameExistError
{
}
message
OpConfDeviceTagNoSetError
{
}
message
PlacementError
{
}
message
BlobSplitAxisInferError
{
}
message
UnknownJobBuildAndInferError
{
}
message
ProtoParseFailedError
{
}
message
CheckFailedError
{
}
message
TodoError
{
}
message
UnimplementedError
{
}
message
RuntimeError
{
}
message
OutOfMemoryError
{
}
message
BoxingNotSupportedError
{
}
message
GradientFunctionNotFoundError
{
}
message
OpKernelNotFoundError
{
repeated
string
op_kernels_not_found_debug_str
=
1
;
}
message
MultipleOpKernelsMatchedError
{
repeated
string
matched_op_kernels_debug_str
=
1
;
}
message
MemoryZoneOutOfMemoryError
{
repeated
string
machine_id
=
1
;
repeated
string
mem_zone_id
=
2
;
repeated
string
device_tag
=
3
;
repeated
string
required
=
4
;
repeated
string
available
=
5
;
}
message
LossBlobNotFoundError
{
}
message
RwMutexedObjectNotFoundError
{
}
message
UnknownError
{
}
message
CompileOptionWrongError
{
}
message
InputDeviceNotMatchError
{
repeated
string
info
=
1
;
}
message
ErrorStackFrame
{
required
string
file
=
1
;
required
int64
line
=
2
;
required
string
function
=
3
;
required
string
error_msg
=
4
;
}
message
SymbolIdUninitializedError
{}
message
InvalidValueError
{}
message
IndexError
{}
message
TypeError
{}
message
TimeoutError
{}
message
ValueNotFoundError
{}
message
ErrorProto
{
optional
string
error_summary
=
1
[
default
=
""
];
optional
string
msg
=
2
[
default
=
""
];
repeated
ErrorStackFrame
stack_frame
=
3
;
oneof
error_type
{
ConfigAssertFailedError
config_assert_failed_error
=
12
;
ConfigResourceUnavailableError
config_resource_unavailable_error
=
13
;
ProtoParseFailedError
proto_parse_failed_error
=
15
;
CheckFailedError
check_failed_error
=
16
;
TodoError
todo_error
=
17
;
UnimplementedError
unimplemented_error
=
18
;
BoxingNotSupportedError
boxing_not_supported_error
=
19
;
GradientFunctionNotFoundError
gradient_function_not_found_error
=
20
;
OpKernelNotFoundError
op_kernel_not_found_error
=
21
;
MultipleOpKernelsMatchedError
multiple_op_kernels_matched_error
=
22
;
MemoryZoneOutOfMemoryError
memory_zone_out_of_memory_error
=
23
;
LossBlobNotFoundError
loss_blob_not_found_error
=
24
;
JobSetEmptyError
job_set_empty_error
=
25
;
DeviceTagNotFoundError
device_tag_not_found_error
=
26
;
InvalidValueError
invalid_value_error
=
27
;
IndexError
index_error
=
28
;
TypeError
type_error
=
29
;
RuntimeError
runtime_error
=
30
;
OutOfMemoryError
out_of_memory_error
=
32
;
TimeoutError
timeout_error
=
40
;
ValueNotFoundError
value_not_found_error
=
31
;
JobNameExistError
job_name_exist_error
=
100
;
JobNameEmptyError
job_name_empty_error
=
101
;
JobNameNotEqualError
job_name_not_equal_error
=
102
;
NoJobBuildAndInferCtxError
no_job_build_and_infer_ctx_error
=
200
;
JobConfFrozenError
job_conf_frozen_error
=
300
;
JobConfNotSetError
job_conf_not_set_error
=
301
;
JobConfRepeatedSetError
job_conf_repeated_set_error
=
302
;
JobTypeNotSetError
job_type_not_set_error
=
303
;
LogicalBlobNameNotExistError
logical_blob_name_not_exist_error
=
400
;
LogicalBlobNameExistError
logical_blob_name_exist_error
=
401
;
LogicalBlobNameInvalidError
logical_blob_name_invalid_error
=
402
;
OpNameExistError
op_name_exist_error
=
450
;
OpConfDeviceTagNoSetError
op_conf_device_tag_no_set_error
=
460
;
PlacementError
placement_error
=
470
;
BlobSplitAxisInferError
blob_split_axis_infer_error
=
480
;
UnknownJobBuildAndInferError
unknown_job_build_and_infer_error
=
500
;
RwMutexedObjectNotFoundError
rw_mutexed_object_not_found_error
=
600
;
SymbolIdUninitializedError
symbol_id_uninitialized_error
=
700
;
UnknownError
unknown_error
=
900
;
CompileOptionWrongError
compile_option_wrong_error
=
950
;
InputDeviceNotMatchError
input_device_not_match_error
=
1000
;
}
}
oneflow/core/common/error_util.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <sstream>
#include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/graph_scope_vars.h"
namespace
oneflow
{
namespace
{
std
::
string
StripSpace
(
std
::
string
str
)
{
if
(
str
.
size
()
==
0
)
{
return
""
;
}
size_t
pos
=
str
.
find_first_not_of
(
" "
);
if
(
pos
!=
std
::
string
::
npos
)
{
str
.
erase
(
0
,
pos
);
}
pos
=
str
.
find_last_not_of
(
" "
);
if
(
pos
!=
std
::
string
::
npos
)
{
str
.
erase
(
pos
+
1
);
}
return
str
;
}
bool
IsLetterNumberOrUnderline
(
char
c
)
{
return
(
c
>=
'0'
&&
c
<=
'9'
)
||
(
c
>=
'a'
&&
c
<=
'z'
)
||
(
c
>=
'A'
&&
c
<=
'Z'
)
||
(
c
==
'_'
);
}
Maybe
<
std
::
string
>
ShortenMsg
(
std
::
string
str
)
{
// 150 characters is the threshold
const
int
num_character_threshold
=
150
;
const
int
num_displayed_character
=
50
;
if
(
str
.
size
()
==
0
)
{
return
str
;
}
// strip space when JUST( xx );
str
=
StripSpace
(
str
);
if
(
str
.
size
()
<
num_character_threshold
)
{
return
str
;
}
// left part whose number of characters is just over 50
int
left_index
=
num_displayed_character
;
bool
pre_condition
=
IsLetterNumberOrUnderline
(
str
.
at
(
left_index
));
for
(;
left_index
<
str
.
size
();
left_index
++
)
{
bool
cur_condition
=
IsLetterNumberOrUnderline
(
str
.
at
(
left_index
));
if
((
pre_condition
&&
!
cur_condition
)
||
(
!
pre_condition
&&
cur_condition
))
{
break
;
}
}
// right part whose number of characters is just over 50
int
right_index
=
str
.
size
()
-
num_displayed_character
;
pre_condition
=
IsLetterNumberOrUnderline
(
str
.
at
(
right_index
));
for
(;
right_index
>=
0
;
right_index
--
)
{
bool
cur_condition
=
IsLetterNumberOrUnderline
(
str
.
at
(
right_index
));
if
((
pre_condition
&&
!
cur_condition
)
||
(
!
pre_condition
&&
cur_condition
))
{
right_index
++
;
break
;
}
}
// a long word of more than 150
if
(
right_index
-
left_index
<
50
)
{
return
str
;
}
std
::
stringstream
ss
;
CHECK_OR_RETURN
(
left_index
>=
0
);
CHECK_OR_RETURN
(
left_index
<
str
.
size
());
ss
<<
str
.
substr
(
0
,
left_index
);
ss
<<
" ... "
;
CHECK_OR_RETURN
(
right_index
>=
0
);
CHECK_OR_RETURN
(
right_index
<
str
.
size
());
ss
<<
str
.
substr
(
right_index
);
return
ss
.
str
();
}
// file info in stack frame
std
::
string
FormatFileOfStackFrame
(
const
std
::
string
&
file
)
{
std
::
stringstream
ss
;
ss
<<
"
\n
File
\"
"
<<
file
<<
"
\"
, "
;
return
ss
.
str
();
}
// line info in stack frame
std
::
string
FormatLineOfStackFrame
(
const
int64_t
&
line
)
{
std
::
stringstream
ss
;
ss
<<
"line "
<<
line
<<
","
;
return
ss
.
str
();
}
// function info in stack frame
std
::
string
FormatFunctionOfStackFrame
(
const
std
::
string
&
function
)
{
std
::
stringstream
ss
;
ss
<<
" in "
<<
function
;
return
ss
.
str
();
}
// msg in stack frame
Maybe
<
std
::
string
>
FormatMsgOfStackFrame
(
std
::
string
error_msg
,
bool
is_last_stack_frame
)
{
const
bool
debug_mode
=
GetGraphDebugMode
();
// only shorten the message if it is not the last stack frame AND not in debug mode
if
(
!
is_last_stack_frame
&&
!
debug_mode
)
{
error_msg
=
*
JUST
(
ShortenMsg
(
error_msg
));
}
// error_msg of last stack frame come from "<<"
if
(
is_last_stack_frame
)
{
error_msg
=
StripSpace
(
error_msg
);
}
std
::
stringstream
ss
;
ss
<<
"
\n
"
<<
error_msg
;
return
ss
.
str
();
}
// the error_summary and msg in error proto
std
::
string
FormatErrorSummaryAndMsgOfErrorProto
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
std
::
stringstream
ss
;
if
(
error
->
has_error_summary
())
{
ss
<<
error
->
error_summary
();
}
if
(
error
->
has_msg
())
{
ss
<<
(
ss
.
str
().
size
()
!=
0
?
"
\n
"
+
error
->
msg
()
:
error
->
msg
());
}
return
ss
.
str
();
}
// the msg in error type instance.
Maybe
<
std
::
string
>
FormatMsgOfErrorType
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
CHECK_NE_OR_RETURN
(
error
->
error_type_case
(),
ErrorProto
::
ERROR_TYPE_NOT_SET
)
<<
Error
::
RuntimeError
()
<<
"Parse error failed, unknown error type"
;
std
::
stringstream
ss
;
const
google
::
protobuf
::
Descriptor
*
error_des
=
error
->
GetDescriptor
();
const
google
::
protobuf
::
OneofDescriptor
*
oneof_field_des
=
error_des
->
FindOneofByName
(
"error_type"
);
const
google
::
protobuf
::
Reflection
*
error_ref
=
error
->
GetReflection
();
const
google
::
protobuf
::
FieldDescriptor
*
field_des
=
error_ref
->
GetOneofFieldDescriptor
(
*
error
,
oneof_field_des
);
CHECK_OR_RETURN
(
field_des
!=
nullptr
);
ss
<<
"Error Type: "
<<
field_des
->
full_name
();
return
ss
.
str
();
}
}
// namespace
Maybe
<
std
::
string
>
FormatErrorStr
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
)
{
std
::
stringstream
ss
;
// Get msg from stack frame of error proto
for
(
auto
stack_frame
=
error
->
mutable_stack_frame
()
->
rbegin
();
stack_frame
<
error
->
mutable_stack_frame
()
->
rend
();
stack_frame
++
)
{
ss
<<
FormatFileOfStackFrame
(
stack_frame
->
file
())
<<
FormatLineOfStackFrame
(
stack_frame
->
line
())
<<
FormatFunctionOfStackFrame
(
stack_frame
->
function
())
<<
*
JUST
(
FormatMsgOfStackFrame
(
stack_frame
->
error_msg
(),
stack_frame
==
error
->
mutable_stack_frame
()
->
rend
()
-
1
));
}
// Get msg from error summary and msg of error proto
std
::
string
error_summary_and_msg_of_error_proto
=
FormatErrorSummaryAndMsgOfErrorProto
(
error
);
if
(
error_summary_and_msg_of_error_proto
.
size
()
!=
0
)
{
ss
<<
"
\n
"
<<
error_summary_and_msg_of_error_proto
;
}
// Get msg from error type of error proto
std
::
string
msg_of_error_type
=
*
JUST
(
FormatMsgOfErrorType
(
error
));
if
(
msg_of_error_type
.
size
()
!=
0
)
{
ss
<<
"
\n
"
<<
msg_of_error_type
;
}
return
ss
.
str
();
}
}
// namespace oneflow
oneflow/core/common/error_util.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ERROR_UTIL_H
#define ONEFLOW_CORE_COMMON_ERROR_UTIL_H
#include <string>
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/maybe.h"
namespace
oneflow
{
Maybe
<
std
::
string
>
FormatErrorStr
(
const
std
::
shared_ptr
<
ErrorProto
>&
error
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ERROR_UTIL_H
oneflow/core/common/exception.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_EXCEPTION_H_
#define ONEFLOW_CORE_COMMON_EXCEPTION_H_
#include <exception>
#include <string>
namespace
oneflow
{
class
Exception
:
public
std
::
exception
{
public:
explicit
Exception
(
const
std
::
string
&
what
)
:
what_
(
what
)
{}
virtual
~
Exception
()
=
default
;
const
char
*
what
()
const
noexcept
override
{
return
what_
.
c_str
();
}
private:
std
::
string
what_
;
};
class
RuntimeException
:
public
Exception
{
public:
using
Exception
::
Exception
;
};
class
TypeException
:
public
Exception
{
public:
using
Exception
::
Exception
;
};
class
IndexException
:
public
Exception
{
public:
using
Exception
::
Exception
;
};
class
NotImplementedException
:
public
Exception
{
public:
using
Exception
::
Exception
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_EXCEPTION_H_
oneflow/core/common/flat_shape.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/shape.h"
namespace
oneflow
{
/*static*/
Maybe
<
FlatShape
>
FlatShape
::
New
(
const
Shape
&
shape
)
{
const
auto
&
flat_shape
=
std
::
make_shared
<
FlatShape
>
();
JUST
(
flat_shape
->
Init
(
shape
));
return
flat_shape
;
}
Maybe
<
void
>
FlatShape
::
Init
(
const
Shape
&
shape
)
{
CHECK_LE_OR_RETURN
(
shape
.
NumAxes
(),
SHAPE_MAX_AXIS_SIZE
);
this
->
clear_dim
();
for
(
int
i
=
0
;
i
<
shape
.
NumAxes
();
++
i
)
{
*
this
->
mutable_dim
()
->
Add
()
=
shape
.
At
(
i
);
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
void
>
FlatShape
::
Check
(
const
Shape
&
shape
)
const
{
CHECK_EQ_OR_RETURN
(
this
->
dim_size
(),
shape
.
NumAxes
())
<<
Error
::
RuntimeError
()
<<
"Expected same shape on each rank, but found at least two shapes, "
<<
JUST
(
ToShape
())
->
ToString
()
<<
" and "
<<
shape
.
ToString
()
<<
"!"
;
for
(
int
i
=
0
;
i
<
this
->
dim_size
();
++
i
)
{
CHECK_EQ_OR_RETURN
(
this
->
dim
(
i
),
shape
.
At
(
i
));
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
void
>
FlatShape
::
Check
(
const
FlatShape
&
flat_shape
)
const
{
CHECK_EQ_OR_RETURN
(
this
->
dim_size
(),
flat_shape
.
NumAxes
())
<<
Error
::
RuntimeError
()
<<
"Expected input of each rank must have the same size, but got at least two size, "
<<
JUST
(
ToShape
())
->
ToString
()
<<
" and "
<<
JUST
(
flat_shape
.
ToShape
())
->
ToString
();
for
(
int
i
=
0
;
i
<
this
->
dim_size
();
++
i
)
{
CHECK_EQ_OR_RETURN
(
this
->
dim
(
i
),
flat_shape
.
At
(
i
))
<<
Error
::
RuntimeError
()
<<
"Expected input of each rank must have the same size, but got at least two size, "
<<
JUST
(
ToShape
())
->
ToString
()
<<
" and "
<<
JUST
(
flat_shape
.
ToShape
())
->
ToString
();
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
Shape
>
FlatShape
::
ToShape
()
const
{
const
auto
&
shape
=
std
::
make_shared
<
Shape
>
();
JUST
(
ToShape
(
shape
.
get
()));
return
shape
;
}
Maybe
<
void
>
FlatShape
::
ToShape
(
Shape
*
shape
)
const
{
DimVector
dim_vec
;
for
(
int
i
=
0
;
i
<
this
->
dim_size
();
++
i
)
{
dim_vec
.
emplace_back
(
this
->
dim
(
i
));
}
*
shape
=
Shape
(
dim_vec
);
return
Maybe
<
void
>::
Ok
();
}
}
// namespace oneflow
oneflow/core/common/flat_shape.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
#define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
#include <memory>
#include "oneflow/core/intrusive/flat_msg.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/shape_vec.h"
namespace
oneflow
{
class
Shape
;
// clang-format off
FLAT_MSG_BEGIN
(
FlatShape
);
public:
// Methods
static
Maybe
<
FlatShape
>
New
(
const
Shape
&
shape
);
Maybe
<
void
>
Init
(
const
Shape
&
shape
);
Maybe
<
void
>
Check
(
const
Shape
&
shape
)
const
;
Maybe
<
void
>
Check
(
const
FlatShape
&
flat_shape
)
const
;
Maybe
<
Shape
>
ToShape
()
const
;
Maybe
<
void
>
ToShape
(
Shape
*
shape
)
const
;
int64_t
At
(
int
i
)
const
{
return
dim
(
i
);
}
int64_t
NumAxes
()
const
{
return
dim_size
();
}
// Fields
FLAT_MSG_DEFINE_REPEATED
(
int64_t
,
dim
,
SHAPE_MAX_AXIS_SIZE
);
FLAT_MSG_END
(
FlatShape
);
// clang-format on
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
Prev
1
…
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment