Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
473
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4162 additions
and
172 deletions
+4162
-172
oneflow/api/python/symbol/sbp_symbol.cpp
oneflow/api/python/symbol/sbp_symbol.cpp
+2
-2
oneflow/api/python/utils/tensor_utils.cpp
oneflow/api/python/utils/tensor_utils.cpp
+90
-78
oneflow/api/python/utils/tensor_utils.h
oneflow/api/python/utils/tensor_utils.h
+51
-24
oneflow/core/auto_parallel/algorithm_util.cpp
oneflow/core/auto_parallel/algorithm_util.cpp
+33
-0
oneflow/core/auto_parallel/algorithm_util.h
oneflow/core/auto_parallel/algorithm_util.h
+82
-0
oneflow/core/auto_parallel/binary_set.cpp
oneflow/core/auto_parallel/binary_set.cpp
+147
-0
oneflow/core/auto_parallel/binary_set.h
oneflow/core/auto_parallel/binary_set.h
+93
-0
oneflow/core/auto_parallel/boxing_collector.cpp
oneflow/core/auto_parallel/boxing_collector.cpp
+247
-63
oneflow/core/auto_parallel/boxing_collector.h
oneflow/core/auto_parallel/boxing_collector.h
+19
-5
oneflow/core/auto_parallel/sbp_collector.cpp
oneflow/core/auto_parallel/sbp_collector.cpp
+403
-0
oneflow/core/auto_parallel/sbp_collector.h
oneflow/core/auto_parallel/sbp_collector.h
+90
-0
oneflow/core/auto_parallel/sbp_constructor.cpp
oneflow/core/auto_parallel/sbp_constructor.cpp
+457
-0
oneflow/core/auto_parallel/sbp_constructor.h
oneflow/core/auto_parallel/sbp_constructor.h
+83
-0
oneflow/core/auto_parallel/sbp_edge.cpp
oneflow/core/auto_parallel/sbp_edge.cpp
+332
-0
oneflow/core/auto_parallel/sbp_edge.h
oneflow/core/auto_parallel/sbp_edge.h
+140
-0
oneflow/core/auto_parallel/sbp_graph.cpp
oneflow/core/auto_parallel/sbp_graph.cpp
+809
-0
oneflow/core/auto_parallel/sbp_graph.h
oneflow/core/auto_parallel/sbp_graph.h
+139
-0
oneflow/core/auto_parallel/sbp_node.cpp
oneflow/core/auto_parallel/sbp_node.cpp
+708
-0
oneflow/core/auto_parallel/sbp_node.h
oneflow/core/auto_parallel/sbp_node.h
+197
-0
oneflow/core/auto_parallel/sbp_util.cpp
oneflow/core/auto_parallel/sbp_util.cpp
+40
-0
No files found.
Too many changes to show.
To preserve performance only
473 of 473+
files are displayed.
Plain diff
Email patch
oneflow/api/python/symbol/sbp_symbol.cpp
View file @
a715222c
...
...
@@ -90,8 +90,8 @@ ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
m
.
attr
(
"max_split_axis"
)
=
kMaxSplitAxis
;
py
::
class_
<
Symbol
<
SbpParallel
>
,
std
::
shared_ptr
<
Symbol
<
SbpParallel
>>>
(
m
,
"sbp"
,
py
::
dynamic_attr
())
.
def
(
"__str__"
,
&
api
::
SbpToString
)
.
def
(
"__repr__"
,
&
api
::
SbpToString
)
.
def
(
"__str__"
,
&
api
::
Api
SbpToString
)
.
def
(
"__repr__"
,
&
api
::
Api
SbpToString
)
.
def
(
py
::
self
==
py
::
self
)
.
def
(
py
::
hash
(
py
::
self
))
.
def
(
"_ToAttrStr"
,
...
...
oneflow/api/python/utils/tensor_utils.cpp
View file @
a715222c
...
...
@@ -15,13 +15,14 @@ limitations under the License.
*/
#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/job/global_mode.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/consistency_check.h"
...
...
@@ -32,11 +33,11 @@ namespace py = pybind11;
namespace
oneflow
{
namespace
one
{
Maybe
<
void
>
Eager
Mirrored
TensorZeros
(
const
std
::
shared_ptr
<
Tensor
>&
t
)
{
Maybe
<
void
>
Eager
Local
TensorZeros
(
const
std
::
shared_ptr
<
Tensor
>&
t
)
{
JUST
(
functional
::
CheckInplaceValid
(
t
));
std
::
shared_ptr
<
Mirrored
Tensor
>
local_tensor
;
std
::
shared_ptr
<
Local
Tensor
>
local_tensor
;
if
(
t
->
is_local
())
{
local_tensor
=
JUST
(
t
->
As
Mirrored
Tensor
());
local_tensor
=
JUST
(
t
->
As
Local
Tensor
());
}
else
{
local_tensor
=
JUST
(
t
->
cur_rank_phy_tensor
());
}
...
...
@@ -44,9 +45,9 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
JUST
(
builder
->
AccessBlobByCallback
(
local_tensor
,
[](
uint64_t
of_blob_ptr
)
{
a
uto
*
of_blob
=
reinterpret_cast
<
OfBlob
*>
(
of_blob_ptr
);
of_blob
->
AsyncAutoMemset
(
0
);
[](
ep
::
Stream
*
stream
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
eager_blob_object
)
{
A
uto
Memset
(
stream
,
eager_blob_object
->
mut_dptr
(),
0
,
eager_blob_object
->
ByteSizeOfBlobBody
(),
eager_blob_object
->
mem_case
()
);
},
"mut"
));
return
Maybe
<
void
>::
Ok
();
...
...
@@ -54,38 +55,25 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
return
Maybe
<
void
>::
Ok
();
}
template
<
typename
T
>
Maybe
<
void
>
CopyMirroredTensorFromUntypedArray
(
const
std
::
shared_ptr
<
Tensor
>&
tensor
,
PyObject
*
array
)
{
return
CopyBetweenMirroredTensorAndNumpy
<
T
>
(
tensor
,
array
,
BlobNumpyCopyUtil
<
T
>::
From
,
"mut"
,
/*block_host_until_done=*/
false
);
}
Maybe
<
std
::
string
>
GetCopyMirroredTensorToNumpyFuncName
(
DataType
dtype
)
{
using
namespace
oneflow
;
static
const
HashMap
<
int64_t
,
std
::
shared_ptr
<
std
::
string
>>
data_type2func_name
{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, std::make_shared<std::string>("_copy_to_numpy_" #type_cpp)},
OF_PP_FOR_EACH_TUPLE
(
DATA_TYPE_FUNC_NAME_PAIR
,
POD_DATA_TYPE_SEQ
)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return
JUST
(
MapAt
(
data_type2func_name
,
static_cast
<
int64_t
>
(
dtype
)));
namespace
{
void
CopyFromNumpyArray
(
ep
::
Stream
*
stream
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
eager_blob_object
,
const
NumPyArrayPtr
&
array_ptr
)
{
SyncAutoMemcpy
(
stream
,
eager_blob_object
->
mut_dptr
(),
array_ptr
.
data
(),
eager_blob_object
->
ByteSizeOfBlobBody
(),
eager_blob_object
->
mem_case
(),
memory
::
MakeHostMemCase
());
}
}
// namespace
Maybe
<
std
::
string
>
GetCopyMirroredTensorFromNumpyFuncName
(
DataType
dtype
)
{
using
namespace
oneflow
;
static
const
HashMap
<
int64_t
,
std
::
shared_ptr
<
std
::
string
>>
data_type2func_name
{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, std::make_shared<std::string>("_copy_from_numpy_" #type_cpp)},
OF_PP_FOR_EACH_TUPLE
(
DATA_TYPE_FUNC_NAME_PAIR
,
POD_DATA_TYPE_SEQ
)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return
JUST
(
MapAt
(
data_type2func_name
,
static_cast
<
int64_t
>
(
dtype
)));
Maybe
<
void
>
CopyLocalTensorFromUntypedArray
(
const
std
::
shared_ptr
<
Tensor
>&
tensor
,
PyObject
*
array
)
{
return
CopyBetweenLocalTensorAndNumpy
(
tensor
,
array
,
CopyFromNumpyArray
,
"mut"
,
/*block_host_until_done=*/
false
);
}
Maybe
<
std
::
tuple
<
std
::
vector
<
Shape
>
,
std
::
vector
<
Symbol
<
DType
>>>>
MaybeGetTensorBufferShapesAndDTypes
(
const
std
::
shared_ptr
<
Tensor
>&
t
)
{
const
auto
&
tensor
=
JUST
(
t
->
As
Mirrored
Tensor
());
const
auto
&
tensor
=
JUST
(
t
->
As
Local
Tensor
());
if
(
tensor
->
dtype
()
!=
DType
::
TensorBuffer
())
{
return
Error
::
RuntimeError
()
<<
"tensor buffer supported only"
;
}
...
...
@@ -93,10 +81,11 @@ MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
std
::
vector
<
Shape
>
shapes
;
std
::
vector
<
Symbol
<
DType
>>
dtypes
;
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
(
1
);
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
();
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
[](
uint64_t
)
{},
"const"
);
tensor
,
btb
,
[](
ep
::
Stream
*
stream
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
)
{},
"const"
);
}));
JUST
(
btb
->
WaitUntilCntEqualZero
(
VirtualMachine
::
GetPredicatorNoMoreInstructionsFinished
()));
...
...
@@ -136,41 +125,51 @@ Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
return
tuple
;
}
#define MAKE_SWITCH_ENTRY(func_name, dtype) func_name<dtype>
DEFINE_STATIC_SWITCH_FUNC
(
Maybe
<
void
>
,
CopyMirroredTensorFromUntypedArray
,
MAKE_SWITCH_ENTRY
,
MAKE_DATA_TYPE_CTRV_SEQ
(
POD_AND_HALF_DATA_TYPE_SEQ
));
Maybe
<
Tensor
>
MakeLocalTensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Optional
<
Symbol
<
Device
>>&
device
,
const
bool
requires_grad
,
const
bool
pin_memory
)
{
PyObject
*
array
=
NULL
;
bool
is_bfloat16_dtype
=
dtype
?
JUST
(
dtype
)
->
data_type
()
==
DataType
::
kBFloat16
:
false
;
bool
is_cuda_device
=
device
?
JUST
(
device
)
->
enum_type
()
==
DeviceType
::
kCUDA
:
false
;
if
(
is_bfloat16_dtype
&&
is_cuda_device
)
{
#if (CUDA_VERSION < 11000)
return
Error
::
RuntimeError
()
<<
"Cannot create a bfloat16 tensor on gpu under cuda version: 11000"
;
#endif // CUDA_VERSION >= 11000
#ifdef WITH_ROCM
return
Error
::
RuntimeError
()
<<
"Cannot create a bfloat16 tensor on gpu under ROCm for now"
;
#endif // WITH_ROCM
}
PyArray_Descr
*
np_dtype
=
dtype
.
has_value
()
dtype
.
has_value
()
&&
!
is_bfloat16_dtype
?
PyArray_DescrFromType
(
JUST
(
numpy
::
OFDataTypeToNumpyType
(
JUST
(
dtype
)
->
data_type
())))
:
nullptr
;
// PyArray_FromAny steals a reference to np_dtype object, so no need to decref it.
// NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
// array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
// NPY_ARRAY_FORCECAST is needed otherwise there will a segfault.
array
=
PyArray_FromAny
(
data
,
np_dtype
,
0
,
0
,
NPY_ARRAY_DEFAULT
|
NPY_ARRAY_ENSURECOPY
|
NPY_ARRAY_FORCECAST
,
nullptr
);
if
(
!
array
)
{
//
// Even though PyArray_FromAny can cast the input array to the desired dtype
// if `dtype` argument is set, it fails to handle the following case:
// >> x = [flow.tensor([1, 2])] * 3 <-- x is a list of flow.Tensor
// >> y = flow.tensor(x, dtype=flow.float32) <-- returns nullptr
// However, the following case without `dtype` argument works well:
// >> x = [flow.tensor([1, 2])] * 3
// >> y = flow.tensor(x)
// So we cast the input array to the desired dtype manually.
PyArrayObject
*
_array
=
reinterpret_cast
<
PyArrayObject
*>
(
PyArray_FromAny
(
data
,
nullptr
,
0
,
0
,
NPY_ARRAY_DEFAULT
|
NPY_ARRAY_ENSURECOPY
|
NPY_ARRAY_FORCECAST
,
nullptr
));
if
(
!
_array
)
{
return
Error
::
RuntimeError
()
<<
"Can not convert input data to a new numpy array."
;
}
// flow.tensor([1., 2.]).dtype should be flow.float32 rather than flow.float64
if
(
!
PyArray_Check
(
data
))
{
int
np_array_type
=
PyArray_TYPE
(
reinterpret_cast
<
PyArrayObject
*>
(
array
));
// Cast to float if data is double sequence, rather than numpy array.
if
(
np_array_type
==
NPY_DOUBLE
&&
np_dtype
==
nullptr
)
{
PyObject
*
fp32_array
=
PyArray_Cast
(
reinterpret_cast
<
PyArrayObject
*>
(
array
),
NPY_FLOAT
);
Py_DECREF
(
array
);
array
=
fp32_array
;
}
}
// PyArray_FromArray steals a reference to np_dtype object, so no need to decref it.
PyObject
*
array
=
PyArray_FromArray
(
_array
,
np_dtype
,
NPY_ARRAY_DEFAULT
|
NPY_ARRAY_ENSURECOPY
|
NPY_ARRAY_FORCECAST
);
Py_DECREF
(
_array
);
auto
*
np_arr
=
reinterpret_cast
<
PyArrayObject
*>
(
array
);
const
npy_intp
*
dims_ptr
=
PyArray_SHAPE
(
np_arr
);
const
Shape
shape
(
DimVector
(
dims_ptr
,
dims_ptr
+
PyArray_NDIM
(
np_arr
)));
DataType
data_type
=
JUST
(
numpy
::
GetOFDataTypeFromNpArray
(
np_arr
));
DataType
np_
data_type
=
JUST
(
numpy
::
GetOFDataTypeFromNpArray
(
np_arr
));
Symbol
<
Device
>
device_
;
if
(
device
)
{
...
...
@@ -179,10 +178,17 @@ Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DTyp
device_
=
JUST
(
Device
::
New
(
"cpu"
));
}
std
::
shared_ptr
<
Tensor
>
tensor
=
JUST
(
functional
::
Empty
(
shape
,
JUST
(
DType
::
Get
(
data_type
)),
device_
,
/*pin_memory=*/
pin_memory
));
JUST
(
SwitchCopyMirrored
TensorFromUntypedArray
(
SwitchCase
(
data_type
),
tensor
,
array
));
functional
::
Empty
(
shape
,
JUST
(
DType
::
Get
(
np_
data_type
)),
device_
,
/*pin_memory=*/
pin_memory
));
JUST
(
CopyLocal
TensorFromUntypedArray
(
tensor
,
array
));
Py_DECREF
(
array
);
if
(
dtype
&&
JUST
(
dtype
)
->
data_type
()
!=
np_data_type
)
{
tensor
=
JUST
(
functional
::
To
(
tensor
,
JUST
(
dtype
),
false
));
}
else
if
(
!
dtype
&&
!
PyArray_Check
(
data
)
&&
tensor
->
dtype
()
->
is_floating_point
()
&&
GetDefaultDType
()
!=
tensor
->
dtype
())
{
// If it not assign dtype and created from PySequence, cast tensor to default floating dtype
tensor
=
JUST
(
functional
::
To
(
tensor
,
JUST
(
DType
::
Get
(
DataType
::
kFloat
)),
false
));
}
JUST
(
tensor
->
set_requires_grad
(
requires_grad
));
return
tensor
;
}
...
...
@@ -201,7 +207,7 @@ auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);
}
// namespace
Maybe
<
Tensor
>
Make
Consistent
TensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
Maybe
<
Tensor
>
Make
Global
TensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
Symbol
<
ParallelDesc
>
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
)
{
...
...
@@ -229,9 +235,13 @@ Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol
}
Symbol
<
Device
>
device
=
JUST
(
Device
::
New
(
placement
->
device_tag
()));
std
::
shared_ptr
<
Tensor
>
local_tensor
=
std
::
shared_ptr
<
Tensor
>
local_tensor
;
{
GlobalMode
::
Guard
guard
(
/* disable global mode */
false
);
local_tensor
=
JUST
(
functional
::
Empty
(
shape
,
JUST
(
DType
::
Get
(
data_type
)),
device
,
/*pin_memory=*/
false
));
JUST
(
SwitchCopyMirroredTensorFromUntypedArray
(
SwitchCase
(
data_type
),
local_tensor
,
array
));
}
JUST
(
CopyLocalTensorFromUntypedArray
(
local_tensor
,
array
));
Py_DECREF
(
array
);
// Cast to float if data is double sequence, rather than numpy array.
...
...
@@ -246,14 +256,16 @@ Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol
size_t
sbp_dims
=
sbp_tuple
.
size
();
Symbol
<
NdSbp
>
broadcast_nd_sbp
=
JUST
(
CachedGetAllBroadcastNdSbp
(
sbp_dims
));
std
::
shared_ptr
<
Tensor
>
broadcast_tensor
=
JUST
(
functional
::
LocalToConsistent
(
local_tensor
,
placement
,
*
JUST
(
GetSbpList
(
broadcast_nd_sbp
)),
shape
,
local_tensor
->
dtype
()));
std
::
shared_ptr
<
Tensor
>
broadcast_tensor
=
JUST
(
functional
::
LocalToGlobal
(
local_tensor
,
placement
,
*
JUST
(
GetSbpList
(
broadcast_nd_sbp
)),
shape
,
local_tensor
->
dtype
(),
/* sync_data */
true
,
/*copy=*/
false
));
std
::
vector
<
Symbol
<
SbpParallel
>>
grad_sbp_tuple
;
auto
consistent_tensor
=
JUST
(
functional
::
ToConsistent
(
broadcast_tensor
,
placement
,
sbp_tuple
,
grad_sbp_tuple
,
/* check_meta */
false
));
JUST
(
consistent_tensor
->
set_requires_grad
(
requires_grad
));
return
consistent_tensor
;
auto
global_tensor
=
JUST
(
functional
::
ToGlobal
(
broadcast_tensor
,
placement
,
sbp_tuple
,
grad_sbp_tuple
,
/* check_meta */
false
,
/*copy=*/
false
));
JUST
(
global_tensor
->
set_requires_grad
(
requires_grad
));
return
global_tensor
;
}
Maybe
<
Tensor
>
MakeTensorFromOtherTensor
(
const
std
::
shared_ptr
<
Tensor
>&
other
,
...
...
@@ -265,9 +277,9 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const
Symbol
<
NdSbp
>&
nd_sbp
=
JUST
(
other
->
nd_sbp
());
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
=
*
JUST
(
GetSbpList
(
nd_sbp
));
std
::
vector
<
Symbol
<
SbpParallel
>>
grad_sbp_tuple
;
// TODO:(zhaoluyang)
consistent
case support pin_memory
return
functional
::
To
Consistent
(
other
,
JUST
(
other
->
parallel_desc
()),
sbp_tuple
,
grad_sbp_tuple
,
/* check_meta */
false
);
// TODO:(zhaoluyang)
global
case support pin_memory
return
functional
::
To
Global
(
other
,
JUST
(
other
->
parallel_desc
()),
sbp_tuple
,
grad_sbp_tuple
,
/* check_meta */
false
,
/*copy=*/
false
);
}
}
...
...
@@ -283,7 +295,7 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
tensor
=
JUST
(
functional
::
Copy
(
other
,
device_
->
type
(),
device_
->
device_id
(),
pin_memory
&&
!
dtype
.
has_value
()));
}
else
{
tensor
=
JUST
(
functional
::
Consistent
ToLocal
(
other
));
tensor
=
JUST
(
functional
::
Global
ToLocal
(
other
,
/*copy=*/
false
));
if
(
!
device
)
{
device_
=
JUST
(
Device
::
New
(
"cpu"
));
}
tensor
=
JUST
(
functional
::
Copy
(
tensor
,
device_
->
type
(),
device_
->
device_id
(),
pin_memory
&&
!
dtype
.
has_value
()));
...
...
@@ -302,9 +314,9 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
)
{
std
::
vector
<
Symbol
<
SbpParallel
>>
grad_sbp_tuple
;
bool
check_meta
=
other
->
is_
consistent
()
?
false
:
true
;
std
::
shared_ptr
<
Tensor
>
tensor
=
JUST
(
functional
::
ToConsistent
(
other
,
placement
,
sbp_tuple
,
grad_sbp_tuple
,
check_meta
));
bool
check_meta
=
other
->
is_
global
()
?
false
:
true
;
std
::
shared_ptr
<
Tensor
>
tensor
=
JUST
(
functional
::
ToGlobal
(
other
,
placement
,
sbp_tuple
,
grad_sbp_tuple
,
check_meta
,
/*copy=*/
false
));
if
(
dtype
)
{
const
Symbol
<
DType
>&
dtype_
=
JUST
(
dtype
);
if
(
tensor
->
dtype
()
!=
dtype_
)
{
...
...
oneflow/api/python/utils/tensor_utils.h
View file @
a715222c
...
...
@@ -29,10 +29,13 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/blocking_then_busy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/common/foreign_lock_helper.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/profiler/profiler.h"
namespace
py
=
pybind11
;
...
...
@@ -55,13 +58,13 @@ struct format_descriptor<oneflow::float16> {
namespace
oneflow
{
namespace
one
{
Maybe
<
void
>
Eager
Mirrored
TensorZeros
(
const
std
::
shared_ptr
<
Tensor
>&
t
);
Maybe
<
void
>
Eager
Local
TensorZeros
(
const
std
::
shared_ptr
<
Tensor
>&
t
);
template
<
typename
T
>
inline
static
Maybe
<
PyObject
*>
Eager
Mirrored
TensorToNumpy
(
PyObject
*
py_tensor
)
{
inline
static
Maybe
<
PyObject
*>
Eager
Local
TensorToNumpy
(
PyObject
*
py_tensor
)
{
const
auto
&
t
=
PyTensor_Unpack
(
py_tensor
);
std
::
shared_ptr
<
Mirrored
Tensor
>
tensor
=
JUST
(
t
->
As
Mirrored
Tensor
());
std
::
shared_ptr
<
Local
Tensor
>
tensor
=
JUST
(
t
->
As
Local
Tensor
());
CHECK_OR_RETURN
(
JUST
(
tensor
->
device
())
==
JUST
(
Device
::
New
(
"cpu"
)));
CHECK_OR_RETURN
(
tensor
->
is_eager
())
<<
"eager tensors supported only."
;
// set base object attr
...
...
@@ -74,12 +77,13 @@ inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
numpy
::
OFStrideToNumpyStride
(
*
JUST
(
tensor
->
stride
()),
tensor
->
dtype
()
->
data_type
());
T
*
data_ptr
=
nullptr
;
const
auto
&
Callback
=
[
&
](
uint64_t
ofblob_ptr
)
{
data_ptr
=
reinterpret_cast
<
OfBlob
*>
(
ofblob_ptr
)
->
mut_blob
()
->
mut_dptr
<
T
>
();
const
auto
&
Callback
=
[
&
](
ep
::
Stream
*
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
eager_blob_object
)
{
data_ptr
=
eager_blob_object
->
mut_dptr
<
T
>
();
};
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
(
1
);
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
();
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
Callback
,
"
mu
t"
);
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
Callback
,
"
cons
t"
);
}));
JUST
(
btb
->
WaitUntilCntEqualZero
(
VirtualMachine
::
GetPredicatorNoMoreInstructionsFinished
()));
return
py
::
array
(
py
::
buffer_info
(
data_ptr
,
sizeof
(
T
),
py
::
format_descriptor
<
T
>::
format
(),
ndim
,
...
...
@@ -90,19 +94,43 @@ inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
}
template
<
typename
T
>
inline
Maybe
<
void
>
CopyBetweenMirroredTensorAndNumpy
(
struct
TensorTypeToPyType
final
{
typedef
T
type
;
};
template
<
>
struct
TensorTypeToPyType
<
float16
>
final
{
typedef
float
type
;
};
template
<
>
struct
TensorTypeToPyType
<
bfloat16
>
final
{
typedef
float
type
;
};
template
<
typename
T
>
inline
static
Maybe
<
PyObject
*>
EagerLocalTensorItem
(
const
std
::
shared_ptr
<
Tensor
>&
tensor
)
{
// OF_PROFILER_RANGE_GUARD("EagerLocalTensorItem");
T
value
=
JUST
(
GetItemInScalarTensor
<
T
>
(
tensor
));
return
functional
::
CastToPyObject
(
static_cast
<
typename
TensorTypeToPyType
<
T
>::
type
>
(
value
));
}
inline
Maybe
<
void
>
CopyBetweenLocalTensorAndNumpy
(
const
std
::
shared_ptr
<
Tensor
>&
t
,
PyObject
*
array
,
Maybe
<
void
>
(
*
Copy
)(
uint64_t
,
const
NumPyArrayPtr
&
),
const
std
::
string
&
modifier
,
bool
block_host_until_done
)
{
auto
tensor
=
JUST
(
t
->
AsMirroredTensor
());
void
(
*
Copy
)(
ep
::
Stream
*
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
,
const
NumPyArrayPtr
&
),
const
std
::
string
&
modifier
,
bool
block_host_until_done
)
{
auto
tensor
=
JUST
(
t
->
AsLocalTensor
());
CHECK_OR_RETURN
(
tensor
->
is_contiguous
())
<<
"contiguous tensors supported only."
;
CHECK_OR_RETURN
(
tensor
->
is_eager
())
<<
"eager tensors supported only."
;
if
(
block_host_until_done
)
{
NumPyArrayPtr
array_ptr
(
array
);
const
auto
&
Callback
=
[
array_ptr
,
Copy
](
uint64_t
ofblob_ptr
)
{
CHECK_JUST
(
Copy
(
ofblob_ptr
,
array_ptr
));
const
auto
&
Callback
=
[
array_ptr
,
Copy
](
ep
::
Stream
*
stream
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
eager_blob_object
)
{
Copy
(
stream
,
eager_blob_object
,
array_ptr
);
};
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
(
1
);
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
();
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
Callback
,
modifier
);
}));
...
...
@@ -119,17 +147,16 @@ inline Maybe<void> CopyBetweenMirroredTensorAndNumpy(
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
AccessBlobByCallback
(
tensor
,
[
array_ptr
,
Copy
](
uint64_t
ofblob_ptr
)
{
CHECK_JUST
(
Copy
(
ofblob_ptr
,
array_ptr
));
},
[
array_ptr
,
Copy
](
ep
::
Stream
*
stream
,
const
std
::
shared_ptr
<
vm
::
EagerBlobObject
>&
eager_blob_object
)
{
Copy
(
stream
,
eager_blob_object
,
array_ptr
);
},
modifier
);
}));
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
std
::
string
>
GetCopyMirroredTensorToNumpyFuncName
(
DataType
dtype
);
Maybe
<
std
::
string
>
GetCopyMirroredTensorFromNumpyFuncName
(
DataType
dtype
);
Maybe
<
std
::
tuple
<
std
::
vector
<
Shape
>
,
std
::
vector
<
Symbol
<
DType
>>>>
MaybeGetTensorBufferShapesAndDTypes
(
const
std
::
shared_ptr
<
Tensor
>&
t
);
...
...
@@ -144,7 +171,7 @@ Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DTyp
const
Optional
<
Symbol
<
Device
>>&
device
,
const
bool
requires_grad
,
const
bool
pin_memory
);
Maybe
<
Tensor
>
Make
Consistent
TensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
Maybe
<
Tensor
>
Make
Global
TensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
Symbol
<
ParallelDesc
>
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
);
...
...
oneflow/core/auto_parallel/algorithm_util.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/algorithm_util.h"
namespace
oneflow
{
namespace
auto_parallel
{
// Inverse function of order
// The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate
// equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after
// v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] <
// order[3] We know the strict order.
void
InverseOrder
(
const
std
::
vector
<
int32_t
>&
order
,
std
::
vector
<
int32_t
>&
inverse_order
)
{
inverse_order
.
resize
(
order
.
size
());
for
(
int32_t
i
=
0
;
i
<
order
.
size
();
i
++
)
{
inverse_order
[
order
[
i
]]
=
i
;
}
}
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/algorithm_util.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <unordered_map>
#include <vector>
namespace
oneflow
{
namespace
auto_parallel
{
// Removes the i-th element from `v` in constant time by overwriting it with
// the last element and dropping the tail slot. Element order is NOT
// preserved, so only use this on vectors whose ordering is irrelevant.
// When erasing several indices in one pass, travel from back to front so an
// already-visited element is never moved into an unvisited slot.
template<class T>
void RemoveFrom(std::vector<T>& v, int32_t i) {
  const auto last = v.size() - 1;
  v[i] = v[last];
  v.resize(last);
}
// Removes (at most) one element equal to `t` from `v` in O(n), using the
// constant-time unordered-removal trick. Scanning starts from the back, so
// when duplicates exist the LAST occurrence is the one removed.
template<class T>
void CheckAndRemoveFrom(std::vector<T>& v, T& t) {
  int32_t idx = static_cast<int32_t>(v.size());
  while (idx-- > 0) {
    if (v[idx] == t) {
      // Inlined O(1) unordered removal: overwrite with the final element.
      v[idx] = v.back();
      v.pop_back();
      break;
    }
  }
}
// Inverse function: transfers a vector to an unordered_map so that
// inverse_map[v[i]] == i. If `v` contains duplicates, the LAST index wins
// (later assignments overwrite earlier ones).
template<class T>
void InverseFunction(const std::vector<T>& v, std::unordered_map<T, int32_t>& inverse_map) {
  inverse_map.clear();
  // Reserve up front so the build does one hash-table allocation at most.
  inverse_map.reserve(v.size());
  // size_t index avoids the signed/unsigned comparison of the original
  // `int32_t i < v.size()` loop.
  for (size_t i = 0; i < v.size(); i++) {
    inverse_map[v[i]] = static_cast<int32_t>(i);
  }
}
// When you want to sort something but you can not move any elements, use order.
// Decide the order of sorting in a list v, we have
//   v[order[i]] < v[order[j]] for all i<j.
// We could define the comparison, then we have
//   comp(v[order[i]], v[order[j]]) == true for all i<j.
// `T` only needs `.size()` and `operator[]`; `comp` is a strict weak ordering
// over the elements. Ties keep an unspecified relative order (std::sort).
template<class T, class Compare>
void DecideOrder(const T& v, std::vector<int32_t>& order, const Compare& comp) {
  // Initialize order with the identity permutation 0, 1, ..., size-1.
  order.resize(v.size());
  std::iota(order.begin(), order.end(), 0);
  // Sort indices by comparing the elements they refer to.
  std::sort(order.begin(), order.end(),
            [&](int32_t i, int32_t j) { return comp(v[i], v[j]); });
}
// Inverse function of order
// The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate
// equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after
// v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] <
// order[3] We know the strict order.
void
InverseOrder
(
const
std
::
vector
<
int32_t
>&
order
,
std
::
vector
<
int32_t
>&
inverse_order
);
}
// namespace auto_parallel
// Multiplicative tolerance factors for floating-point comparisons —
// presumably used to absorb rounding error when comparing computed costs
// (a <= b is tested as a < b * kFloatDeviationPlus, etc.); confirm at call
// sites. `constexpr` makes them usable in constant expressions and
// guarantees no dynamic initialization in every including TU.
static constexpr double kFloatDeviationMinus = 0.9999999;
static constexpr double kFloatDeviationPlus = 1.0000001;
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
oneflow/core/auto_parallel/binary_set.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/binary_set.h"
namespace
oneflow
{
namespace
auto_parallel
{
namespace {

// A static function for initialization of the log_2 mapping: maps each
// single-bit word value 2^n to its bit index n, for every bit position of
// BinarySetEntryType. Used by QuickOutput() to turn an isolated bit back
// into an element index.
std::unordered_map<BinarySetEntryType, int32_t> InitLog2() {
  std::unordered_map<BinarySetEntryType, int32_t> log_2;
  for (int32_t i = 0; i < 8 * static_cast<int32_t>(sizeof(BinarySetEntryType)); i++) {
    // Widen BEFORE shifting: the original `1 << i` shifts a plain int, which
    // is undefined behavior for i >= 31 and can never reach the upper bits
    // of a word wider than int.
    log_2[static_cast<BinarySetEntryType>(1) << i] = i;
  }
  return log_2;
}

// Initialization of log_2 mapping
// Take log2 of an integer value: 2^n -> n.
const std::unordered_map<BinarySetEntryType, int32_t> log_2 = InitLog2();

}  // namespace
// Constructor: allocates enough machine words to hold `size_of_set` bits,
// all cleared (empty set). bit_entry_type_ is the bit width of one word.
BinarySet::BinarySet(int32_t size_of_set) : size_of_set_(size_of_set) {
  // Ceiling division: number of words needed to cover all element bits.
  const int32_t num_words = (size_of_set - 1) / bit_entry_type_ + 1;
  binary_set_values_.resize(num_words, 0);
}
// (Re)initialization if needed: records the capacity and grows the backing
// storage to cover it.
// NOTE(review): vector::resize only zero-fills NEW words — bits already set
// in existing words survive, and a smaller size never shrinks the storage.
// Presumably callers Initialize a fresh set or follow with Clear(); confirm
// if reuse with a different size is intended.
void BinarySet::Initialize(int32_t size_of_set) {
  size_of_set_ = size_of_set;
  // Ceiling division: number of words needed to cover all element bits.
  const int32_t num_words = (size_of_set - 1) / bit_entry_type_ + 1;
  binary_set_values_.resize(num_words, 0);
}
// Clear all the elements in the set: zero every backing word. Capacity
// (size_of_set_ and the word count) is unchanged.
void BinarySet::Clear() {
  for (auto& word : binary_set_values_) { word = 0; }
}
// Check whether the i-th element is in this subset by testing its bit.
bool BinarySet::CheckExistence(int32_t i) const {
  const int32_t word_index = i / bit_entry_type_;
  const int32_t bit_index = i % bit_entry_type_;
  return ((binary_set_values_[word_index] >> bit_index) & 1) != 0;
}
// Add i-th element into this subset
void
BinarySet
::
AddEntry
(
int32_t
i
)
{
int32_t
k
=
i
/
bit_entry_type_
;
int32_t
j
=
i
%
bit_entry_type_
;
binary_set_values_
[
k
]
|=
(
1
<<
j
);
}
// Take i-th element out from this subset
void
BinarySet
::
DeleteEntry
(
int32_t
i
)
{
int32_t
k
=
i
/
bit_entry_type_
;
int32_t
j
=
i
%
bit_entry_type_
;
binary_set_values_
[
k
]
&=
~
(
1
<<
j
);
}
// Get the union with another subset and store it into u
void
BinarySet
::
UnionTo
(
const
BinarySet
&
bs
,
BinarySet
&
u
)
{
for
(
int32_t
k
=
0
;
k
<
binary_set_values_
.
size
();
k
++
)
{
u
.
binary_set_values_
[
k
]
=
binary_set_values_
[
k
]
|
bs
.
binary_set_values_
[
k
];
}
}
// Returns true if this binary set shares at least one element with bs.
// Only the words both operands actually have are inspected.
bool BinarySet::IfIntersect(const BinarySet& bs) const {
  const size_t num_common_words =
      std::min(binary_set_values_.size(), bs.binary_set_values_.size());
  for (size_t k = 0; k < num_common_words; k++) {
    if ((binary_set_values_[k] & bs.binary_set_values_[k]) != 0) { return true; }
  }
  return false;
}
// Get the intersection with another subset and store it into i (i = *this & bs).
// Words beyond the shorter operand contribute nothing to the intersection.
void BinarySet::IntersectionTo(const BinarySet& bs, BinarySet& i) const {
  const size_t min_bs_size =
      std::min(binary_set_values_.size(), bs.binary_set_values_.size());
  if (min_bs_size > i.binary_set_values_.size()) {
    i.binary_set_values_.resize(min_bs_size, 0);
  }
  // Walk only the words both operands have. The previous version iterated
  // over all of *this's words and read bs.binary_set_values_[k] out of
  // bounds whenever *this had more words than bs.
  for (size_t k = 0; k < min_bs_size; k++) {
    i.binary_set_values_[k] = binary_set_values_[k] & bs.binary_set_values_[k];
  }
}
// Count number of elements in this subset
int32_t
BinarySet
::
Total
()
const
{
int32_t
t
=
0
;
for
(
int32_t
k
=
0
;
k
<
binary_set_values_
.
size
();
k
++
)
{
BinarySetEntryType
bsv
=
binary_set_values_
[
k
];
bsv
=
(
bsv
&
0x5555555555555555
)
+
((
bsv
>>
1
)
&
0x5555555555555555
);
bsv
=
(
bsv
&
0x3333333333333333
)
+
((
bsv
>>
2
)
&
0x3333333333333333
);
bsv
=
(
bsv
&
0x0F0F0F0F0F0F0F0F
)
+
((
bsv
>>
4
)
&
0x0F0F0F0F0F0F0F0F
);
bsv
=
(
bsv
&
0x00FF00FF00FF00FF
)
+
((
bsv
>>
8
)
&
0x00FF00FF00FF00FF
);
bsv
=
(
bsv
&
0x0000FFFF0000FFFF
)
+
((
bsv
>>
16
)
&
0x0000FFFF0000FFFF
);
// bsv = (bsv & 0x00000000FFFFFFFF) + ((bsv >> 32) & 0x00000000FFFFFFFF);
t
+=
int32_t
(
bsv
);
}
return
t
;
}
// Collect every element of the subset into out, in increasing order.
// out is cleared first; elements are probed one by one via CheckExistence.
void BinarySet::Output(std::vector<int32_t>& out) const {
  out.clear();
  for (int32_t entry = 0; entry < size_of_set_; ++entry) {
    if (CheckExistence(entry)) { out.emplace_back(entry); }
  }
}
// Output all the elements in the subset
void
BinarySet
::
QuickOutput
(
std
::
vector
<
int32_t
>&
out
)
const
{
out
.
clear
();
for
(
int32_t
i
=
0
;
i
<
binary_set_values_
.
size
();
i
++
)
{
BinarySetEntryType
x
=
binary_set_values_
[
i
];
BinarySetEntryType
y
=
0
;
while
(
x
)
{
y
=
x
;
x
&=
x
-
1
;
out
.
emplace_back
(
i
*
BinarySet
::
bit_entry_type_
+
log_2
.
find
(
y
-
x
)
->
second
);
}
}
}
// Insert every element of in into this subset.
void BinarySet::AddEntries(std::vector<int32_t>& in) {
  for (const int32_t entry : in) { AddEntry(entry); }
}
// Two binary sets are equal iff they have the same capacity and identical bits.
bool BinarySet::operator==(const BinarySet& rhs) const {
  if (size_of_set_ != rhs.size_of_set_) { return false; }
  for (size_t word_id = 0; word_id < binary_set_values_.size(); ++word_id) {
    if (binary_set_values_[word_id] != rhs.binary_set_values_[word_id]) { return false; }
  }
  return true;
}
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/binary_set.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
#include <cstdlib>
#include <unordered_map>
#include <vector>
#include "oneflow/core/common/hash.h"
namespace
oneflow
{
namespace
auto_parallel
{
// The log_2 index currently supports only 32-bit int entries (root cause unknown).
// No other issues have been observed with unsigned int.
using
BinarySetEntryType
=
unsigned
int
;
// A fixed-capacity set of small non-negative integers, stored as a bit vector:
// element i lives in bit (i % bit_entry_type_) of word
// binary_set_values_[i / bit_entry_type_].
class BinarySet {
 public:
  BinarySet() {}
  explicit BinarySet(int32_t size_of_set);
  // Allocate and zero storage for a set holding elements in [0, size_of_set).
  void Initialize(int32_t size_of_set);
  // Clear all the elements in the set (capacity is kept).
  void Clear();
  // Check if the i-th element is in this subset.
  bool CheckExistence(int32_t i) const;
  // Add the i-th element into this subset.
  void AddEntry(int32_t i);
  // Take the i-th element out from this subset.
  void DeleteEntry(int32_t i);
  // Get the union with another subset and store it into u.
  void UnionTo(const BinarySet& bs, BinarySet& u);
  // If this binary set intersects another one.
  bool IfIntersect(const BinarySet& bs) const;
  // Get the intersection with another subset and store it into i.
  void IntersectionTo(const BinarySet& bs, BinarySet& i) const;
  // Count the number of elements in this subset.
  int32_t Total() const;
  // Output all the elements in the subset, probing each index in turn.
  void Output(std::vector<int32_t>& out) const;
  // Output all the elements in the subset by scanning set bits only
  // (faster than Output for sparse sets).
  void QuickOutput(std::vector<int32_t>& out) const;
  // Add elements of the input vector into this subset.
  void AddEntries(std::vector<int32_t>& in);
  // If two binary sets are equal to each other.
  bool operator==(const BinarySet& rhs) const;
  // Capacity of the set (maximum number of distinct elements).
  inline int32_t GetSizeOfSet() const { return size_of_set_; };

 private:
  friend struct BinarySetHasher;
  // binary_set_values_ contains a vector of 64-bit or 32-bit int.
  // Each bit means whether an entry is in the set.
  std::vector<BinarySetEntryType> binary_set_values_;
  // Capacity; -1 marks a not-yet-initialized set.
  int32_t size_of_set_ = -1;
  // total bits of the entry type in vector binary_set_values_.
  static constexpr int32_t bit_entry_type_ = 8 * sizeof(BinarySetEntryType);
};
struct
BinarySetHasher
{
std
::
size_t
operator
()(
const
BinarySet
&
bs
)
const
{
using
std
::
hash
;
using
std
::
size_t
;
size_t
h
=
0
;
for
(
int
i
=
0
;
i
<
bs
.
binary_set_values_
.
size
();
i
++
)
{
h
=
HashCombine
(
h
,
hash
<
BinarySetEntryType
>
()(
bs
.
binary_set_values_
[
i
]));
}
return
h
;
};
};
}
// namespace auto_parallel
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
oneflow/core/auto_parallel/boxing_collector.cpp
View file @
a715222c
...
...
@@ -16,8 +16,10 @@ limitations under the License.
#include <memory>
#include <string>
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/global_for.h"
...
...
@@ -34,9 +36,9 @@ limitations under the License.
namespace
oneflow
{
namespace
{
void
DfsSetNdSbp
(
const
std
::
vector
<
::
oneflow
::
SbpParallel
>&
id2sbp_parallel
,
int32_t
depth
,
int32_t
max_depth
,
NdSbp
&
nd_sbp
,
std
::
vector
<
NdSbp
>&
nd_sbp_lists
,
std
::
unordered_map
<
::
oneflow
::
NdSbp
,
int32_t
>&
nd_sbp_universe
)
{
void
DfsSetNdSbp
(
const
std
::
vector
<
SbpParallel
>&
id2sbp_parallel
,
int32_t
depth
,
int32_t
max_depth
,
NdSbp
&
nd_sbp
,
std
::
vector
<
NdSbp
>&
nd_sbp_lists
,
std
::
unordered_map
<
NdSbp
,
int32_t
>&
nd_sbp_universe
)
{
if
(
depth
==
max_depth
)
{
nd_sbp_universe
[
nd_sbp
]
=
nd_sbp_lists
.
size
();
nd_sbp_lists
.
push_back
(
nd_sbp
);
...
...
@@ -49,7 +51,7 @@ void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int
}
// Let a nd sbp be consistent with the given hierarchy number
Maybe
<
NdSbp
>
SetNdSbpDim
(
NdSbp
nd_sbp
,
int32_t
hierarchy_num
)
{
Maybe
<
NdSbp
>
SetNdSbpDim
(
const
NdSbp
&
nd_sbp
,
int32_t
hierarchy_num
)
{
// Do not need to change
if
(
nd_sbp
.
sbp_parallel_size
()
==
hierarchy_num
)
{
return
nd_sbp
;
}
// (S0, S0) -> S0
...
...
@@ -71,9 +73,63 @@ Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
return
new_sbp
;
}
// Product of the hierarchy extents along every split dimension of nd_sbp,
// i.e. the total number of pieces the tensor is cut into under this sbp.
int32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {
  int32_t num_split = 1;
  for (int32_t dim = 0; dim < nd_sbp.sbp_parallel_size(); ++dim) {
    if (nd_sbp.sbp_parallel(dim).has_split_parallel()) {
      num_split *= parallel_desc.hierarchy()->At(dim);
    }
  }
  return num_split;
}
// Dealing with 1D sbp to 1D sbp.
// Specifically, S -> P: inserts an all-broadcast middle node and sets
// *diag_node_pos; all other 1D combinations need no middle node.
Maybe<void> AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
                                      const ParallelDesc& producer_parallel_desc,
                                      const ParallelDesc& consumer_parallel_desc,
                                      std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
  // Only a partial-sum consumer may need a middle node here.
  if (!sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) { return Maybe<void>::Ok(); }
  // Support [4]: P <--> [2, 2]: (P, P)
  // Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
  if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()
      && sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {
    return Maybe<void>::Ok();
  }
  // B -> P needs no middle node either.
  if (sbp_producer.sbp_parallel(0).has_broadcast_parallel()) { return Maybe<void>::Ok(); }
  // S -> B -> P (Large cost!)
  // TODO: Please implement S -> P directly.
  // We do not support [3]: P <--> [2, 2]: (P, P) as well.
  int32_t hierarchy_size = 0;
  if (producer_parallel_desc.hierarchy()->elem_cnt()
      < consumer_parallel_desc.hierarchy()->elem_cnt()) {
    // The diagonal node uses the parallel description from the producer:
    // (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
    *diag_node_pos = 1;
    hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();
  } else {
    // The diagonal node uses the parallel description from the consumer:
    // S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
    *diag_node_pos = 0;
    hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();
  }
  // Build the all-broadcast middle sbp matching the chosen hierarchy.
  NdSbp broadcast_nd;
  for (int32_t dim = 0; dim < hierarchy_size; ++dim) {
    broadcast_nd.add_sbp_parallel();
    broadcast_nd.mutable_sbp_parallel(dim)->mutable_broadcast_parallel();
  }
  middle_sbps.emplace_back(broadcast_nd);
  return Maybe<void>::Ok();
}
}
// namespace
// A constructor with init, designed for
uncustomiz
ed boxing collector
// A constructor with init, designed for
pre-stor
ed boxing collector
BoxingCollector
::
BoxingCollector
(
int32_t
max_axis
)
{
CHECK_JUST
(
Init
(
max_axis
));
}
// Construct a boxing collector with given maximum number of axis
...
...
@@ -92,6 +148,8 @@ Maybe<void> BoxingCollector::Init(int32_t max_axis) {
JUST
(
GenerateCombination4SamePlacement
(
3
));
JUST
(
GenerateCombination4DiffHierarchy
(
this
,
this
));
JUST
(
GenerateCombination4DiffPlacement
(
this
,
this
));
init_type_
=
int32_t
(
enable_general_basic_communication
||
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
());
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -106,6 +164,8 @@ Maybe<void> BoxingCollector::Init(const BlobDesc& logical_blob_desc,
// Get copy cost in lazy mode
LazyMode
::
Guard
enable_lazy_mode
(
true
);
JUST
(
GenerateCombination4SamePlacement
(
5
,
logical_blob_desc
,
parallel_desc
));
init_type_
=
int32_t
(
enable_general_basic_communication
||
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
());
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -173,6 +233,7 @@ void BoxingCollector::GenerateMap1d2nd() {
// Generate the id Map from 1d sbp to nd sbp
NdSbp
nd_sbp
;
for
(
int32_t
dim_sbp
=
0
;
dim_sbp
<
hierarchy_num_
;
dim_sbp
++
)
{
nd_sbp
.
add_sbp_parallel
();
}
id_1d_2_nd_
.
clear
();
id_1d_2_nd_
.
resize
(
m
,
-
1
);
for
(
int32_t
id_1d
=
0
;
id_1d
<
m
;
id_1d
++
)
{
for
(
int32_t
dim_sbp
=
0
;
dim_sbp
<
hierarchy_num_
;
dim_sbp
++
)
{
...
...
@@ -190,10 +251,13 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
// NOTE: The performance of this function are all the same with different hierarchy
int32_t
world_size
=
GlobalProcessCtx
::
WorldSize
();
Shape
hierarchy44
({
4
*
world_size
,
4
*
world_size
});
int32_t
virtual_range_size
=
hierarchy44
.
elem_cnt
();
std
::
shared_ptr
<
Shape
>
virtual_hierarchy
=
std
::
make_shared
<
Shape
>
(
hierarchy44
);
auto
parallel_desc
=
JUST
(
ParallelDesc
::
New
(
"cpu"
,
{
"0:0-"
+
std
::
to_string
(
hierarchy44
.
elem_cnt
()
-
1
)},
virtual_hierarchy
));
BlobDesc
blob_desc
({
16
,
16
,
16
,
16
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
BlobDesc
blob_desc
({
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
JUST
(
GenerateCombination4SamePlacement
(
max_middle_node_num
,
blob_desc
,
*
parallel_desc
));
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -204,7 +268,9 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
const
ParallelDesc
&
parallel_desc
)
{
// Store the origin transfer cost information
int32_t
n
=
nd_sbp_lists_
.
size
();
minimum_copy_cost_
.
clear
();
minimum_copy_cost_
.
resize
(
n
);
middle_nodes_
.
clear
();
middle_nodes_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
minimum_copy_cost_
[
i
].
resize
(
n
);
...
...
@@ -250,7 +316,7 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
minimum_copy_cost_
[
i
][
j
]
=
curr_copy_cost
;
}
}
// If the minimum copy cost rem
i
ans infinity, adding one middle node does not make it.
// If the minimum copy cost rema
i
ns infinity, adding one middle node does not make it.
if
(
minimum_copy_cost_
[
i
][
j
]
>
GetValidMaxCopyCost
())
{
continue
;
}
// Find those middle nodes
for
(
int32_t
k
=
0
;
k
<
n
;
k
++
)
{
...
...
@@ -291,6 +357,7 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffHierarchy(
// Search the path that contains one of the diagonal sbp
int32_t
n
=
nd_sbp_lists_
.
size
();
diag_node_diff_hierarchy_
.
clear
();
diag_node_diff_hierarchy_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
diag_node_diff_hierarchy_
[
i
].
resize
(
n
);
...
...
@@ -309,7 +376,10 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
)
{
// Virtual parallel and blob description
int32_t
world_size
=
GlobalProcessCtx
::
WorldSize
();
BlobDesc
blob_desc
({
16
,
16
,
16
,
16
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
int32_t
virtual_range_size
=
4
*
world_size
*
(
4
*
world_size
+
1
);
BlobDesc
blob_desc
({
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
,
virtual_range_size
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
// Virtual placements before transfer
Shape
in_hierarchy44
({
4
*
world_size
+
1
,
4
*
world_size
});
std
::
shared_ptr
<
Shape
>
in_hierarchy
=
std
::
make_shared
<
Shape
>
(
in_hierarchy44
);
...
...
@@ -334,6 +404,7 @@ Maybe<void> BoxingCollector::ComputeCostFor1DSbpDiffPlacement(
// Number of 1d sbp
int32_t
m
=
id2sbp_parallel_
.
size
();
// Compute the cost while transferring a 1D sbp between different placements
cost_4_diff_placement
.
clear
();
cost_4_diff_placement
.
resize
(
m
);
for
(
int32_t
id_1d_producer
=
0
;
id_1d_producer
<
m
;
id_1d_producer
++
)
{
cost_4_diff_placement
[
id_1d_producer
].
resize
(
m
,
GetMaxVal
<
float
>
());
...
...
@@ -364,6 +435,7 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
// Search the path that contains two of the diagonal sbp
int32_t
n
=
nd_sbp_lists_
.
size
();
diag_node_diff_placement_
.
clear
();
diag_node_diff_placement_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
diag_node_diff_placement_
[
i
].
resize
(
n
);
...
...
@@ -496,64 +568,74 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
if
(
ParseBooleanFromEnv
(
"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"
,
false
))
{
return
Maybe
<
void
>::
Ok
();
}
// If compute_cost==false + 2D sbp + same placment + nccl logical + not (p->b),
// Use nccl logical send recv instead of middle node.
// Note that in op sbp inference, cost of middle nodes is still used for the moment.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
if
(
compute_cost
==
false
&&
producer_parallel_desc
.
hierarchy
()
->
NumAxes
()
==
2
&&
producer_parallel_desc
==
consumer_parallel_desc
&&
!
(
NdSbpHasPartialParallel
(
sbp_consumer
))
&&
// TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open
// this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P)
// -> (P, S0) -> (B, S0), neither same dim 0 or send recv in nccl logical pass can deal with
// (P, P) -> (P, S0) at the moment.
// !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) &&
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
())
{
VLOG
(
3
)
<<
"Middle node insertion is skipped when src sbp is "
<<
NdSbpToString
(
sbp_producer
)
<<
" dst sbp is "
<<
NdSbpToString
(
sbp_consumer
)
<<
", because nccl logical send/recv can handle this."
;
if
(
producer_parallel_desc
==
consumer_parallel_desc
&&
sbp_producer
==
sbp_consumer
)
{
return
Maybe
<
void
>::
Ok
();
}
#endif // WITH_CUDA
// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
if
(
Is1dSbp
(
sbp_producer
)
&&
Is1dSbp
(
sbp_consumer
))
{
if
(
sbp_consumer
.
sbp_parallel
(
0
).
has_partial_sum_parallel
())
{
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if
(
producer_parallel_desc
.
parallel_num
()
==
consumer_parallel_desc
.
parallel_num
()
&&
sbp_producer
.
sbp_parallel
(
0
).
has_partial_sum_parallel
())
{
JUST
(
AskSbpCombinationFor1DSbp
(
sbp_producer
,
sbp_consumer
,
producer_parallel_desc
,
consumer_parallel_desc
,
middle_sbps
,
diag_node_pos
));
// No middle nodes for the other 1d-sbp combinations
return
Maybe
<
void
>::
Ok
();
}
if
(
!
sbp_producer
.
sbp_parallel
(
0
).
has_broadcast_parallel
())
{
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
// We do not support [3]: P <--> [2, 2]: (P, P) as well.
int32_t
hierarchy_size
=
0
;
if
(
producer_parallel_desc
.
hierarchy
()
->
elem_cnt
()
<
consumer_parallel_desc
.
hierarchy
()
->
elem_cnt
())
{
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*
diag_node_pos
=
1
;
hierarchy_size
=
producer_parallel_desc
.
hierarchy
()
->
NumAxes
();
}
else
{
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*
diag_node_pos
=
0
;
hierarchy_size
=
consumer_parallel_desc
.
hierarchy
()
->
NumAxes
();
#ifdef WITH_CUDA
// Use a general basic communication if no P in the consumer
if
(((
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
()
&&
producer_parallel_desc
==
consumer_parallel_desc
)
||
enable_general_basic_communication
)
&&
(
!
NdSbpHasPartialParallel
(
sbp_consumer
))
&&
producer_parallel_desc
.
device_type
()
==
DeviceType
::
kCUDA
&&
consumer_parallel_desc
.
device_type
()
==
DeviceType
::
kCUDA
)
{
if
(
NdSbpHasPartialParallel
(
sbp_producer
)
&&
NdSbpHasBroadcastParallel
(
sbp_consumer
))
{
// (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
// Directly applying general basic communication would have O(n^2) time complexity for P->B
// Using two-step transfer would reduce it to a linear cost
JUST
(
AskSbpCombination4GeneralBasicCommunication
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
middle_sbps
,
diag_node_pos
));
}
// Otherwise, one-step transfer
return
Maybe
<
void
>::
Ok
();
}
#endif // WITH_CUDA
NdSbp
broadcast_nd
;
for
(
int32_t
i
=
0
;
i
<
hierarchy_size
;
i
++
)
{
broadcast_nd
.
add_sbp_parallel
();
broadcast_nd
.
mutable_sbp_parallel
(
i
)
->
mutable_broadcast_parallel
();
#ifdef WITH_ROCM
// Use a general basic communication if no P in the consumer
if
(((
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
()
&&
producer_parallel_desc
==
consumer_parallel_desc
)
||
enable_general_basic_communication
)
&&
(
!
NdSbpHasPartialParallel
(
sbp_consumer
))
&&
producer_parallel_desc
.
device_type
()
==
DeviceType
::
kCUDA
&&
consumer_parallel_desc
.
device_type
()
==
DeviceType
::
kCUDA
)
{
if
(
NdSbpHasPartialParallel
(
sbp_producer
)
&&
NdSbpHasBroadcastParallel
(
sbp_consumer
))
{
// (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
// Directly applying general basic communication would have O(n^2) time complexity for P->B
// Using two-step transfer would reduce it to a linear cost
JUST
(
AskSbpCombination4GeneralBasicCommunication
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
middle_sbps
,
diag_node_pos
));
}
middle_sbps
.
emplace_back
(
broadcast_nd
);
// Otherwise, one-step transfer
return
Maybe
<
void
>::
Ok
();
}
#endif // WITH_ROCM
if
(
JUST
(
ComputeLazyCopyCostBetweenNdSbp
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
/*requires_same_sbp=*/
false
))
<
GetValidMaxCopyCost
())
{
return
Maybe
<
void
>::
Ok
();
}
else
{
int32_t
require_init_type
=
int32_t
(
enable_general_basic_communication
||
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
());
if
(
init_type_
!=
require_init_type
)
{
// We assemble the boxing table from S(0) to S(5).
// Those splitting in higher axes are considered in the customized boxing.
constexpr
int32_t
kRegularMaxSplitAxes
=
6
;
JUST
(
Init
(
kRegularMaxSplitAxes
));
}
}
...
...
@@ -568,6 +650,7 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
// Transfer for the same machines, devices and hierarchy.
if
(
sbp_producer
==
sbp_consumer
)
{
return
Maybe
<
void
>::
Ok
();
}
const
auto
&
parallel_hierarchy
=
producer_parallel_desc
.
hierarchy
();
*
diag_node_pos
=
0
;
// Dealing with nD sbp, n>2
if
(
parallel_hierarchy
->
NumAxes
()
>
2
)
{
...
...
@@ -675,7 +758,7 @@ Maybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(
if
(
same_placement
)
{
// Different hierarchies
CHECK_OR_RETURN
(
diag_node_diff_hierarchy_
.
size
()
>
0
)
<<
"Have not initial
zie
the combination table for different hierarchies yet! "
<<
"Have not initial
ized
the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffHierarchy(this, this)); "
"before Asking sbp combination for different parallel description."
;
if
(
JUST
(
Ask1Combination4DiffPlacement
(
...
...
@@ -687,7 +770,7 @@ Maybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(
}
else
{
// Different placements
CHECK_OR_RETURN
(
diag_node_diff_placement_
.
size
()
>
0
)
<<
"Have not initial
zie
the combination table for different hierarchies yet! "
<<
"Have not initial
ized
the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffPlacement(this, this)); "
"before Asking sbp combination for different parallel description."
;
if
(
JUST
(
Ask1Combination4DiffPlacement
(
...
...
@@ -787,9 +870,9 @@ Maybe<void> BoxingCollector::Generate1Combination4DiffHierarchy(
min_path_length
=
path_length
;
// Find a candidate with small cost
if
(
curr_cost
<
min_cost
*
1.0000001
)
{
if
(
curr_cost
<
min_cost
*
kFloatDeviationPlus
)
{
// Find a smaller cost, clear the previous path.
if
(
curr_cost
<
min_cost
*
0.9999999
)
{
if
(
curr_cost
<
min_cost
*
kFloatDeviationMinus
)
{
min_cost
=
curr_cost
;
diag_nodes
.
clear
();
}
...
...
@@ -1007,4 +1090,105 @@ Maybe<void> BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logica
return
Maybe
<
void
>::
Ok
();
}
// Ask for sbp combination for general basic communication.
// Decides whether a middle node pays off for a partial-sum producer feeding a
// broadcast-containing consumer; if so, appends one close-to-all-split middle
// sbp to middle_sbps and sets *diag_node_pos (1 = diagonal at producer side,
// 0 = at consumer side). Otherwise leaves middle_sbps untouched.
Maybe<void> BoxingCollector::AskSbpCombination4GeneralBasicCommunication(
    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
    std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
  // (P, X) -> (B, X) || (X , P) -> (X, B), X is any SBP
  // One step transfer, at most 50% reduction in the transfer cost, do not use middle nodes
  if (producer_parallel_desc == consumer_parallel_desc
      && producer_parallel_desc.hierarchy()->NumAxes() == 2
      && (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0)
          || sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) {
    return Maybe<void>::Ok();
  }
  // Not enough gain in transfer cost, do not use middle nodes.
  // Heuristic threshold: presumably the two-step saving scales with the
  // product of the ratios while its overhead scales with their sum —
  // TODO(review): confirm against the cost model.
  int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc);
  int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc);
  if (2 * (partial_ratio4producer + broadcast_ratio4consumer)
      >= partial_ratio4producer * broadcast_ratio4consumer) {
    return Maybe<void>::Ok();
  }
  // Choose which side the middle node should sit close to.
  bool close2producer = true;
  if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) {
    // Get close to the one with more splits
    close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc)
                     > TotalNumSplit(sbp_consumer, consumer_parallel_desc);
  } else {
    // Get close to the one with more machines
    close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num();
  }
  // Get the contiguous sbp
  if (close2producer) {
    JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps));
    *diag_node_pos = 1;
  } else {
    JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps));
    *diag_node_pos = 0;
  }
  return Maybe<void>::Ok();
}
// Ask for an all-split sbp which is close to the original one.
// Replaces every P/B component of nd_sbp with S(axis) for some axis whose
// remaining extent is divisible by that hierarchy dimension, then appends the
// result to middle_sbps. If no suitable axis exists, appends nothing and
// returns Ok.
Maybe<void> BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp,
                                                const ParallelDesc& parallel_desc,
                                                const BlobDesc& logical_blob_desc,
                                                std::vector<NdSbp>& middle_sbps) {
  // remain_shape: extent left on each axis after all existing splits.
  Shape remain_shape = logical_blob_desc.shape();
  // rest_split_shape: total split factor applied to each axis so far.
  Shape rest_split_shape = logical_blob_desc.shape();
  int32_t dim_shape = remain_shape.NumAxes();
  // Initialize the remains and splitting
  // logical_blob_desc.shape() == remain_shape .* rest_split_shape;
  for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); }
  // First pass: account for the splits nd_sbp already performs.
  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
    if (sbp.has_split_parallel()) {
      int32_t axis = sbp.split_parallel().axis();
      int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
      remain_shape.Set(axis, remain_shape.At(axis) / split_num);
      rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num);
    }
  }
  // Get the contiguous sbp.
  // Second pass: walk the dimensions in order, converting each P/B to S(axis).
  NdSbp new_sbp = nd_sbp;
  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
    int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
    if (sbp.has_split_parallel()) {
      int32_t axis = sbp.split_parallel().axis();
      // split shape is the total splitting number starting from sbp_id to the end
      rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num);
    } else {
      // change P or B to S(axis)
      int32_t axis = -1;
      // 4096 is large enough, we might not have that much devices
      int32_t min_split_num = 4096;
      // We need to pick a suitable axis
      for (int32_t i = 0; i < remain_shape.NumAxes(); i++) {
        if (remain_shape.At(i) % split_num == 0) {
          if (rest_split_shape.At(i) < min_split_num) {
            // Pick the axis with smallest splitting number among the rest of the sbp
            min_split_num = rest_split_shape.At(i);
            axis = i;
          }
        }
      }
      // P, B -> S(axis)
      if (axis >= 0) {
        new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis);
        remain_shape.Set(axis, remain_shape.At(axis) / split_num);
      } else {
        // Can not find a suitable contiguous sbp; give up without adding a middle node.
        return Maybe<void>::Ok();
      }
    }
  }
  // Add the new sbp into the middle node lists
  middle_sbps.emplace_back(new_sbp);
  return Maybe<void>::Ok();
}
}
// namespace oneflow
oneflow/core/auto_parallel/boxing_collector.h
View file @
a715222c
...
...
@@ -129,10 +129,19 @@ class BoxingCollector final {
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
const
std
::
vector
<
std
::
vector
<
int32_t
>>&
diag_nodes
);
// Ask for sbp combination for general basic communication
Maybe
<
void
>
AskSbpCombination4GeneralBasicCommunication
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
);
// Ask for a all-split sbp which is closed to the original one
Maybe
<
void
>
AskCloseAllSplitSbp
(
const
NdSbp
&
nd_sbp
,
const
ParallelDesc
&
parallel_desc
,
const
BlobDesc
&
logical_blob_desc
,
std
::
vector
<
NdSbp
>&
middle_sbps
);
// Stores all the possible SbpParallel.
HashMap
<
::
oneflow
::
SbpParallel
,
int32_t
>
sbp_parallel_universe_
;
HashMap
<
SbpParallel
,
int32_t
>
sbp_parallel_universe_
;
// Relationship between id and Sbp Parallel
std
::
vector
<
::
oneflow
::
SbpParallel
>
id2sbp_parallel_
;
std
::
vector
<
SbpParallel
>
id2sbp_parallel_
;
// minimum cost
// minimum_copy_cost[producer][consumer]
std
::
vector
<
std
::
vector
<
double
>>
minimum_copy_cost_
;
...
...
@@ -142,18 +151,23 @@ class BoxingCollector final {
// nodes that needs to be inserted
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int32_t
>>>>
middle_nodes_
;
// Stores all the possible NdSbp.
std
::
unordered_map
<
::
oneflow
::
NdSbp
,
int32_t
>
nd_sbp_universe_
;
std
::
unordered_map
<
NdSbp
,
int32_t
>
nd_sbp_universe_
;
// Relationship between id and Nd Sbp
std
::
vector
<
NdSbp
>
nd_sbp_lists_
;
// The diagonal middle node for differe placements
// The diagonal middle node for differe
nt
placements
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int32_t
>>>>
diag_node_diff_placement_
;
// The diagonal middle node for differe hierarchies in the same placement
// The diagonal middle node for differe
nt
hierarchies in the same placement
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int32_t
>>>>
diag_node_diff_hierarchy_
;
// Id Map from 1d sbp to 2d sbp
// For example: B -> (B, B), S0 -> (S0, S0)
std
::
vector
<
int32_t
>
id_1d_2_nd_
;
// The sbp size in the combination table
int32_t
hierarchy_num_
;
// How the boxing collector is initialized
int32_t
init_type_
=
-
1
;
// Enable general basic communication or not
const
bool
enable_general_basic_communication
=
ParseBooleanFromEnv
(
"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"
,
false
);
};
// class BoxingCollector
}
// namespace oneflow
...
...
oneflow/core/auto_parallel/sbp_collector.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <string>
#include "oneflow/core/auto_parallel/sbp_collector.h"
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/auto_parallel/sbp_constructor.h"
namespace
oneflow
{
namespace
auto_parallel
{
namespace
{
// Whether the given binary set intersects all the sbp sets of the consumers
bool
IfIntersectAll
(
const
HashMap
<
std
::
pair
<
std
::
string
,
std
::
string
>
,
BinarySet
>&
consumer_bn2sbp_set
,
const
BinarySet
&
bs
)
{
for
(
const
auto
&
sbp_set_group
:
consumer_bn2sbp_set
)
{
if
(
!
bs
.
IfIntersect
(
sbp_set_group
.
second
))
{
return
false
;
}
}
return
true
;
}
// Mark in unique_sbps every sbp id that occurs in exactly one consumer's set.
// accumulator is a scratch tally indexed by sbp id; it must be zero-filled on
// entry and is restored to all zeros before returning.
void FindUniqueSbpSets(
    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
    const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,
    BinarySet& unique_sbps) {
  std::vector<int32_t> sbp_ids;
  // Tally how many consumer sets mention each sbp id.
  for (const auto& consumer_group : consumer_bn2sbp_set) {
    consumer_group.second.QuickOutput(sbp_ids);
    for (const int32_t sbp_id : sbp_ids) { ++accumulator[sbp_id]; }
  }
  // An sbp id seen exactly once is unique; zero the tally while scanning.
  for (const int32_t sbp_id : all_sbp_set) {
    if (accumulator[sbp_id] == 1) { unique_sbps.AddEntry(sbp_id); }
    accumulator[sbp_id] = 0;
  }
}
// Find unique sbp groups.
// A "unique sbp group" is, for one consumer set, the subset of its sbp that no
// other consumer mentions; groups with >= 2 such sbp are collected into
// unique_sbp_groups. accumulator and bs_buffer are caller-provided scratch;
// both are left cleared on return.
void FindUniqueSbpGroups(
    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
    const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,
    BinarySet& bs_buffer, std::vector<BinarySet>& unique_sbp_groups) {
  // find the unique sbp sets
  BinarySet unique_sbps(accumulator.size());
  FindUniqueSbpSets(consumer_bn2sbp_set, all_sbp_set, accumulator, unique_sbps);
  // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}
  // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them
  for (const auto& sbp_set_group : consumer_bn2sbp_set) {
    // bs_buffer is overwritten (not accumulated) by IntersectionTo on every
    // iteration, so no per-iteration Clear() is needed.
    unique_sbps.IntersectionTo(sbp_set_group.second, bs_buffer);
    // Find those unique sbp groups with more than two sbp
    // For example {B, S1, S2} is an impossible proxy candidate,
    // since {S1, S2} is only contained by A but not contained by C and D.
    // A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2.
    if (bs_buffer.Total() >= 2) { unique_sbp_groups.push_back(bs_buffer); }
  }
  bs_buffer.Clear();
}
// If not contains two sbp from a same unique group
bool
No2SbpFromSameUniqueGroup
(
const
BinarySet
&
bs
,
const
std
::
vector
<
BinarySet
>&
unique_sbp_groups
)
{
BinarySet
intersection
(
bs
.
GetSizeOfSet
());
for
(
const
auto
&
unique_sbp_group
:
unique_sbp_groups
)
{
bs
.
IntersectionTo
(
unique_sbp_group
,
intersection
);
// For example {B, S1, S2} is an impossible proxy candidate,
// since {S1, S2} is only contained by A but not contained by C and D.
// A could be either S1 or S2. The tensor do not need to be transferred to both S1 and S2.
if
(
intersection
.
Total
()
>=
2
)
{
return
false
;
}
}
return
true
;
}
}
// namespace
// Default constructor for SbpCollector
// Don't allow any special case for broadcast!
// NOTE: the universe is intentionally left empty here; it is populated later
// by CollectUniverse(). The commented-out code below pre-seeded id 0 with
// broadcast and was deliberately disabled.
SbpCollector::SbpCollector() {
  // initialize Sbp Parallel Universe with broadcast.
  // NdSbp sbp_broadcast;
  // sbp_broadcast.mutable_broadcast_parallel();
  // nd_sbp_universe_[sbp_broadcast] = 0;
  // id2nd_sbp_.push_back(sbp_broadcast);
}
// Register every NdSbp appearing in the signature. Ids are assigned in
// first-seen order so that id2nd_sbp_[id] is the inverse of nd_sbp_universe_.
void SbpCollector::CollectUniverse(const NdSbpSignature& nd_sbp_sig) {
  for (const auto& bn_and_sbp : nd_sbp_sig.bn_in_op2nd_sbp()) {
    const auto& nd_sbp = bn_and_sbp.second;
    // Skip NdSbp we have already registered.
    if (nd_sbp_universe_.find(nd_sbp) != nd_sbp_universe_.end()) { continue; }
    const int32_t new_id = nd_sbp_universe_.size();
    nd_sbp_universe_[nd_sbp] = new_id;
    id2nd_sbp_.push_back(nd_sbp);
  }
}
// Register every NdSbp used by any signature of the given node.
void SbpCollector::CollectUniverse(const SbpNode* sbp_node) {
  for (const auto& signature : sbp_node->sbp_sig_list_) { CollectUniverse(signature); }
}
// Register every NdSbp in the whole graph, then size the scratch structures
// (per-id counters and the binary-set buffer) to match the final universe.
void SbpCollector::CollectUniverse(const SbpGraph& sbp_graph) {
  for (const auto* node : sbp_graph.node_list_) { CollectUniverse(node); }
  const auto universe_size = nd_sbp_universe_.size();
  accumulator_.resize(universe_size, 0);
  bs_buffer_.Initialize(universe_size);
}
// TODO: Auto Placement!
// It only collect the same sbp with the same parallel description
// In this moment their hierarchy is the same!
// Initialize copy cost from producer to proxy of producer.
// Fills cost_[producer_sig][candidate] on the single incoming edge of
// `sbp_proxy`. Each proxy candidate is a *set* of NdSbp ids, so the entry is
// the sum of the transfer costs to every member of the set.
// `lbi` is the logical blob the proxy represents.
void SbpCollector::InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy,
                                                    const LogicalBlobId& lbi) const {
  // the only edge from producer to proxy of producer
  SbpEdge* sbp_edge = sbp_proxy->edges_in_[0];
  SbpNode* sbp_node_producer = sbp_edge->start_node_;
  // cost_ is indexed [producer signature][proxy candidate]
  sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());
  int32_t consumer_sbp_size = sbp_proxy->parallel_candidates_.size();
  // look through sbp signature in producer; zero-initialize every row
  for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();
       sbp_id_producer++) {
    sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
  }
  // Assemble copy cost from producer to proxy of producer
  OpNode* producer = sbp_node_producer->op_node_;
  // get parallel description. Number of devices.
  const ParallelDesc& producer_parallel_desc = producer->parallel_desc();
  // Need to be careful, the logical blob description should be independent to current
  // NdSbp. Use producer or op_node?
  const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);
  // output blob name of `lbi` inside the producer op
  const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));
  // A buffer to store the sbp parallel id
  std::vector<int32_t> sbp_parallel_ids;
  // look through sbp signature in producer
  for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();
       sbp_id_producer++) {
    // get sbp parallel for a logical blob in producer
    const auto& producer_sbp_bn_in_op2sbp_parallel =
        sbp_node_producer->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();
    const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);
    // look through sbp parallel set in consumer
    for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
      const BinarySet& sbp_parallel_set = sbp_proxy->parallel_candidates_[sbp_id_consumer];
      sbp_parallel_set.QuickOutput(sbp_parallel_ids);
      // look through all sbp parallels in a sbp parallel set;
      // accumulate (+=) the cost of reaching every member of the candidate set
      for (int32_t sbp_parallel_id : sbp_parallel_ids) {
        // get sbp parallel for a logical blob in consumer
        const NdSbp& sbp_consumer = id2nd_sbp_[sbp_parallel_id];
        // compute copy cost for a specific logical blob
        // Use the parallel description of producer as those for consumer for now.
        sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] += CHECK_JUST(
            ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc,
                                           producer_parallel_desc, producer_parallel_desc,
                                           /*is_same=*/false));
      }
    }
  }
}
// Initialize copy cost from proxy of producer to consumers.
// Connects `sbp_proxy` to every consumer listed in `consumer_bn2sbp_set`
// (keyed by {consumer op name, input blob name}) and fills the new edge's
// cost table: 0 when the consumer's required NdSbp is already a member of the
// proxy candidate set, "infinity" (GetMaxVal<float>()) otherwise.
void SbpCollector::InitializeCopyCostFromProxy2Consumer(
    SbpNode* sbp_proxy,
    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
    const HashMap<std::string, SbpNode*>& op_name2sbp_node) const {
  // Connect sbp proxy and consumers
  for (const auto& consumer_bn_group : consumer_bn2sbp_set) {
    // consumer in cost model
    SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;
    // input blob name of logical blob in consumer
    const std::string& ibn = consumer_bn_group.first.second;
    // check is_mutable in consumer; a consumer that enforces the same SBP as
    // the producer must never have been given a proxy
    OpNode* consumer = sbp_node_consumer->op_node_;
    CHECK(!RequireSameSbp(consumer, ibn)) << "Create a proxy for an unsuitable consumer!\n";
    // Connect sbp proxy and consumer
    sbp_proxy->PointTo(sbp_node_consumer);
    // the sbp edge connecting proxy and consumer
    SbpEdge* sbp_edge = sbp_node_consumer->FindEdgeWithNode(sbp_proxy);
    // cost_ is indexed [proxy candidate][consumer signature]
    sbp_edge->cost_.resize(sbp_proxy->parallel_candidates_.size());
    int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();
    // look through sbp parallel set in proxy
    for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_proxy->parallel_candidates_.size();
         sbp_id_producer++) {
      // initialization for copy cost
      sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
      // get sbp parallel set for a logical blob in proxy
      BinarySet& parallel_candidate = sbp_proxy->parallel_candidates_[sbp_id_producer];
      // look through sbp signatures in consumers
      for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
        // get sbp parallel for a logical blob in consumer
        const auto& consumer_sbp_bn_in_op2sbp_parallel =
            sbp_node_consumer->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();
        const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);
        // Candidate set lacks the consumer's NdSbp -> forbid this pairing.
        if ((!parallel_candidate.CheckExistence(nd_sbp_universe_.find(sbp_consumer)->second))) {
          sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] = GetMaxVal<float>();
        }
      }
    }
  }
}
// Export list of possible combination of Sbp Parallels.
// For every logical blob consumed by two or more boxing-capable consumers,
// insert a proxy node between producer and consumers. Each proxy candidate is
// a set of NdSbp ids (built by DfsSbpSet) that could serve all consumers; the
// original producer->consumer edges then unload the blob and may be clipped.
//
// BUGFIX(review): the two skip paths below used `return;` inside a plain
// `for` loop — a leftover from a ForEachNode lambda (note the original `};`
// closing the loop). That silently aborted proxy insertion for ALL remaining
// nodes as soon as one unsupported op/blob was met. Both are now `continue;`,
// matching the intent stated in the adjacent comments ("just skip it").
void SbpCollector::ProxySbpCandidate(const OpGraph& op_graph,
                                     const HashMap<std::string, SbpNode*>& op_name2sbp_node,
                                     SbpGraph& sbp_graph) {
  // If needed, we can output the mapping from operator name to its proxy.
  // HashMap<std::string, HashMap<LogicalBlobId, SbpNode*>>&
  // op_name2lbi2sbp_proxy;
  // mapping from a logical blob id to index
  HashMap<LogicalBlobId, int32_t> lbi2index;
  // mapping from the index to producer, consumer and corresponding input blob name, possible sbp
  // sets
  std::vector<const OpNode*> index2producer;
  std::vector<std::unordered_set<int32_t>> index2sbp_set;
  // mapping from consumers and input blob names to an unordered_set of SBP Parallel.
  std::vector<HashMap<std::pair<std::string, std::string>, BinarySet>> index2consumer_bn2sbp_set;

  // Phase 1: gather, per logical blob, every consumer's candidate SBP set.
  for (auto* consumer_sbp_node : sbp_graph.node_list_) {
    auto* node = consumer_sbp_node->op_node_;
    OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();
    // If not support boxing, just skip it.  (was `return;` — see BUGFIX above)
    if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) { continue; }
    for (const std::string& ibn : node->op().input_bns()) {
      // Skip those blobs who enforce same SBP.
      if (RequireSameSbp(node, ibn)) {
        // Enforcing same SBP. Can not collect sbp from this blob.
        continue;
      }
      const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);
      const OpNode& producer = node->ProducerOpNode4Lbi(lbi);
      // not building proxy for fixed operators  (was `return;` — see BUGFIX above)
      if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { continue; }
      // decide the index of a logical blob description
      const auto& iterator_lbi = lbi2index.find(lbi);
      int32_t index = 0;
      if (iterator_lbi == lbi2index.end()) {
        index = lbi2index.size();
        lbi2index[lbi] = index;
        // map from lbi to the producer
        index2producer.push_back(&producer);
        // Initialize consumer_bns and the sbp sets
        index2consumer_bn2sbp_set.resize(index + 1);
        index2sbp_set.resize(index + 1);
      } else {
        index = iterator_lbi->second;
      }
      // a set to store the id of all possible SBP Parallel for a downstream op
      // should filter out repeated SBP Parallel by pre-storing them into an unordered_set
      BinarySet& nd_sbp_ids = index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}];
      nd_sbp_ids.Initialize(nd_sbp_universe_.size());
      // The union sbp set of all the consumers
      std::unordered_set<int32_t>& union_nd_sbp_ids = index2sbp_set[index];
      for (auto& sbp_sig : consumer_sbp_node->sbp_sig_list_) {
        const auto& map = sbp_sig.bn_in_op2nd_sbp();
        const auto& iter = map.find(ibn);
        CHECK(iter != map.end()) << "blob_name " << ibn << " not found in sbp signature";
        const NdSbp& consumer_sbp = iter->second;
        // filter out repeated SBP
        int32_t sbp_universe_id = nd_sbp_universe_.find(consumer_sbp)->second;
        nd_sbp_ids.AddEntry(sbp_universe_id);
        union_nd_sbp_ids.insert(sbp_universe_id);
      }
    }
  }
  // A set of binary set with broadcast only
  // std::unordered_set<BinarySet, BinarySetHasher> parallel_candidates_initializer;
  // BinarySet one_broadcast(nd_sbp_universe_.size());
  // one_broadcast.AddEntry(0);
  // parallel_candidates_initializer.insert(std::move(one_broadcast));
  // Phase 2: decide if we should insert a proxy for each logical blob.
  for (auto& lbi_index : lbi2index) {
    int32_t index = lbi_index.second;
    // Only insert proxy for those blobs with multiple downstream consumers.
    if (index2consumer_bn2sbp_set[index].size() < 2) { continue; }
    // Maximum number of possible sbp in the proxy
    int32_t max_num_sbp_proxy =
        std::min(max_num_sbp_proxy_, index2consumer_bn2sbp_set[index].size());
    // producer in cost model
    const std::string& producer_name = index2producer[index]->op().op_name();
    SbpNode* sbp_node_producer = op_name2sbp_node.find(producer_name)->second;
    const LogicalBlobId& lbi = lbi_index.first;
    // generate sbp proxy
    SbpNode* sbp_proxy = sbp_graph.GenerateNode();
    // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}
    // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them
    std::vector<BinarySet> unique_sbp_groups;
    FindUniqueSbpGroups(index2consumer_bn2sbp_set[index], index2sbp_set[index], accumulator_,
                        bs_buffer_, unique_sbp_groups);
    // Depth first search to collect Sbp Parallel information for the whole sbp set
    DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(),
              index2consumer_bn2sbp_set[index], unique_sbp_groups,
              sbp_proxy->parallel_candidates_);
    // Initialize computation cost (a proxy performs no computation)
    sbp_proxy->cost_.resize(sbp_proxy->parallel_candidates_.size(), 0);
    // Transfer a logical blob from producer to a sbp proxy of this blob
    sbp_node_producer->PointTo(sbp_proxy);
    // Compute copy cost between producer and proxy
    InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi);
    // Build connection and compute copy cost between proxy and consumers
    InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bn2sbp_set[index],
                                         op_name2sbp_node);
    // Unloading: the blob now travels through the proxy, not the direct edges.
    for (const auto& consumer_bn_group : index2consumer_bn2sbp_set[index]) {
      // consumer in cost model
      SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;
      // the sbp edge connecting producer and consumer
      SbpEdge* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);
      // unload logical blob from sbp edges
      edge_found->UnloadLbi(lbi);
      // Do not clip this edge. Save it for wait time.
      // clip this edge if it no longer carries any blob
      // We don't clip edges before since we have transfer cost
      // Now we clip edges, which makes the topology simpler
      if (edge_found->EmptyLbi() && edge_found->wait_time_ <= 0.0
          && edge_found->wait_time_ > -0.5) {
        sbp_graph.ClipEdge(edge_found);
      }
    }
  }
}
// Depth first search to collect Sbp Parallel information for different logical blob ids.
// Enumerates every subset of `sbp_sets` with 1..max_depth elements; a subset
// is recorded into `parallel_candidates` when it intersects every consumer's
// candidate set and never takes two SBPs from one unique group.
// The current subset lives in the shared members bs_buffer_ / accumulator_
// (membership bitset + visited markers); both are restored on backtrack, so
// they must be empty/zeroed when the outermost call (depth 0) is made.
void SbpCollector::DfsSbpSet(
    int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,
    const std::unordered_set<int32_t>::iterator& start_it,
    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
    const std::vector<BinarySet>& unique_sbp_groups,
    std::vector<BinarySet>& parallel_candidates) {
  if (depth > 0) {
    // Record the current (non-empty) subset if it can serve every consumer
    // and wastes no slot on a second member of a unique group.
    if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer_)
        && No2SbpFromSameUniqueGroup(bs_buffer_, unique_sbp_groups)) {
      // store the binary set into an unordered_set
      parallel_candidates.push_back(bs_buffer_);
    }
  }
  if (depth >= max_depth) { return; }
  // go through the rest of the sbp parallel; starting at `start_it` keeps the
  // enumeration strictly increasing, so each subset is produced exactly once
  std::unordered_set<int32_t>::iterator curr_it = start_it;
  while (curr_it != sbp_sets.end()) {
    // Take the value out
    int32_t nd_sbp_num = *curr_it;
    // Then move to the next pointer
    ++curr_it;
    if (accumulator_[nd_sbp_num] == 0) {
      // Choose nd_sbp_num, recurse, then undo the choice (backtracking).
      bs_buffer_.AddEntry(nd_sbp_num);
      ++accumulator_[nd_sbp_num];
      DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, unique_sbp_groups,
                parallel_candidates);
      bs_buffer_.DeleteEntry(nd_sbp_num);
      --accumulator_[nd_sbp_num];
    }
  }
}
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/sbp_collector.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SBP_COLLECTOR_
#define SBP_COLLECTOR_
#include <unordered_map>
#include <vector>
#include <unordered_set>
#include <utility>
#include <type_traits>
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/job/local_sig_infer_hint.h"
#include "oneflow/core/job/job_builder.h"
// #include "sbp_constructor.h"
#define DEBUG_COLLECTOR_
namespace
oneflow
{
namespace
auto_parallel
{
// Builds "sbp proxy" nodes for logical blobs consumed by several downstream
// operators, so that one tensor can be materialized under a *set* of SBPs
// instead of being transferred once per consumer.
// Typical usage: CollectUniverse(graph) once, then ProxySbpCandidate(...).
class SbpCollector {
 public:
  SbpCollector();
  ~SbpCollector() {}
  // Collect all the possible Sbp Parallel from a SbpGraph and size the
  // internal scratch structures accordingly. Must run before ProxySbpCandidate.
  void CollectUniverse(const SbpGraph& sbp_graph);
  // Export list of possible combination of Sbp Parallels: insert proxy nodes
  // into `sbp_graph` for blobs with multiple boxing-capable consumers.
  void ProxySbpCandidate(const OpGraph& op_graph,
                         const HashMap<std::string, SbpNode*>& op_name2sbp_node,
                         SbpGraph& sbp_graph);

 private:
  // Stores all the possible NdSbp, mapped to their dense integer ids.
  std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;
  // Relationship between id and Sbp Parallel (inverse of nd_sbp_universe_).
  std::vector<NdSbp> id2nd_sbp_;
  // Calculate number of downstream sbp (scratch counters, zeroed between uses).
  std::vector<int32_t> accumulator_;
  // A binary set buffer to indicate sets of downstream sbp (scratch).
  BinarySet bs_buffer_;
  // Collect all the possible Sbp Parallel from a NdSbpSignature
  void CollectUniverse(const NdSbpSignature& nd_sbp_sig);
  // Collect all the possible Sbp Parallel from a SbpNode
  void CollectUniverse(const SbpNode* sbp_node);
  // Initialize copy cost from producer to proxy of producer
  void InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy, const LogicalBlobId& lbi) const;
  // Initialize copy cost from proxy of producer to consumers
  void InitializeCopyCostFromProxy2Consumer(
      SbpNode* sbp_proxy,
      const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
      const HashMap<std::string, SbpNode*>& op_name2sbp_node) const;
  // Maximum number of possible sbp in the proxy (cap on candidate-set size).
  const unsigned long max_num_sbp_proxy_ = 3;
  // Depth first search to collect Sbp Parallel information for the whole sbp set
  void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,
                 const std::unordered_set<int32_t>::iterator& sbp_set_it,
                 const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
                 const std::vector<BinarySet>& unique_sbp_groups,
                 std::vector<BinarySet>& parallel_candidates);
};
// class SbpCollector
}
// namespace auto_parallel
}
// namespace oneflow
#endif // SBP_COLLECTOR_
oneflow/core/auto_parallel/sbp_constructor.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/sbp_constructor.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/auto_parallel/sbp_collector.h"
namespace
oneflow
{
namespace
auto_parallel
{
// Entry point: build the whole sbp cost-model graph from the op graph.
// `job` is forwarded to InitSbpGraph (which only reads its configuration).
Maybe<void> SbpConstructor::Init(const OpGraph& op_graph, Job* job /*Maybe not use*/) {
  JUST(InitSbpGraph(op_graph, *job));
  return Maybe<void>::Ok();
}
// Build the sbp graph: nodes/edges, candidate signatures, computation cost,
// optional trunk weighting and sbp-collector proxies, then copy cost.
// Finally picks an initial strategy (default signature, optionally replaced
// by the user/semi-auto strategy stolen from the op graph).
Maybe<void> SbpConstructor::InitSbpGraph(const OpGraph& op_graph, const Job& job) {
  // TODO: process local node
  JUST(GenerateNodeAndEdge(op_graph, job));
  JUST(FillSbpSignatureForOpNode(op_graph, job));
  JUST(InitComputationCost(op_graph));
  if (enable_trunk_algo_) { JUST(ApplyTrunkAlgo()); }
  if (use_sbp_collector_) {
    // Load logical blobs on all sbp edges.
    LoadLbi2SbpEdge(op_graph);
    // Use sbp collector to create sbp proxy for nodes with multiple downstream operators.
    SbpCollector sbp_collector;
    sbp_collector.CollectUniverse(sbp_graph_);
    sbp_collector.ProxySbpCandidate(op_graph, op_name2sbp_node_, sbp_graph_);
  }
  // Copy cost must be computed after proxies exist so proxy edges are skipped.
  JUST(InitCopyCost(op_graph));
  // TODO: Set all the sbp signature id to be 0 for initialization.
  // Could revert it back to
  // sbp_graph_.RandomSbpSignature(use_sbp_collector_);
  // after settling down the synchronization of sbp strategy.
  sbp_graph_.SetDefaultSbpSig();
  double ori_cost = sbp_graph_.ComputeCost();
  LOG(INFO) << "Initial cost: " << ori_cost;
  // If we do not prune those parallel cast ops, steal the initial strategy from user setting and
  // semi-auto parallelism
  if (!job.job_conf().enable_auto_parallel_ignore_user_sbp_config()) {
    JUST(StealSbpSignatureFromOpNode(op_graph, job));
    ori_cost = sbp_graph_.ComputeCost();
    LOG(INFO) << "OpGraph cost: " << ori_cost;
  }
  return Maybe<void>::Ok();
}
// Search for a low-cost sbp strategy on the already-initialized graph:
// node/edge eliminations first, a greedy restart if the initial strategy is
// invalid (cost above the valid-copy-cost ceiling), then greedy refinement.
// Fails if even the final strategy exceeds the ceiling.
Maybe<void> SbpConstructor::FindBestSbpSignature() {
  double ori_cost = sbp_graph_.ComputeCost();
  LOG(INFO) << "Initial cost: " << ori_cost;
  int elimination_num = sbp_graph_.NodeAndEdgeEliminations();
  LOG(INFO) << "Elimination number: " << elimination_num;
  // An "infinite" initial cost means the starting strategy is infeasible;
  // look for one feasible strategy greedily before refining.
  if (ori_cost > GetValidMaxCopyCost()) {
    JUST(sbp_graph_.Find1Strategy4Greedy());
    ori_cost = sbp_graph_.ComputeCost();
    LOG(INFO) << "Greedy cost: " << ori_cost;
  }
  sbp_graph_.GreedyStrategy(4);
  sbp_graph_.FinalizeSbp();
  double final_cost = sbp_graph_.ComputeCost();
  LOG(INFO) << "Final cost: " << final_cost;
  // Refinement should never make things worse (1.0 tolerance for rounding).
  if (ori_cost + 1.0 < final_cost) { LOG(WARNING) << "ori_cost less than final_cost!!!"; }
  // TODO: Restart searching with another original random strategy
  CHECK_LT_OR_RETURN(final_cost, GetValidMaxCopyCost())
      << "Failed! Auto parallel can't find a strategy with reasonable cost!";
  return Maybe<void>::Ok();
}
// Write the chosen NdSbp signatures back into the job: the parallel-view
// conf (both Nd and, for 1-D hierarchies, the flattened 1-D form) and each
// operator's own conf via its dump hook.
Maybe<void> SbpConstructor::DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job) {
  for (auto& op_conf : *job->mutable_net()->mutable_op()) {
    const OpNode* node = op_graph.OpNode4OpName(op_conf.name());
    SbpNode* sbp_node = op_name2sbp_node_[node->op().op_name()];
    // The signature selected by the search for this operator.
    const NdSbpSignature& nd_sbp_sig = sbp_node->FinalSbpSignature();
    // Update NdSbpSignature
    (*job->mutable_job_parallel_view_conf()
          ->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()]
        .CopyFrom(nd_sbp_sig);
    // If we have 1D SbpSignature Conf
    if (node->parallel_desc().hierarchy()->NumAxes() == 1) {
      // Update SbpSignature
      SbpSignature sbp_signature;
      NdSbpSignatureToSbpSignature(nd_sbp_sig, &sbp_signature);
      (*job->mutable_job_parallel_view_conf()
            ->mutable_op_name2sbp_signature_conf())[node->op().op_name()]
          .CopyFrom(sbp_signature);
    }
    // Let the operator itself record the signature in its OperatorConf.
    JUST(node->op().GetDumpNdSbpSignatureForOpConfFn()(nd_sbp_sig, &op_conf));
  }
  return Maybe<void>::Ok();
}
// Mirror the op graph as an sbp graph: one SbpNode per OpNode, one SbpEdge
// per producer->consumer OpEdge. Nodes and outgoing edges are visited in a
// deterministic (name-sorted) order so repeated runs build identical graphs.
Maybe<void> SbpConstructor::GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job) {
  JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf());
  // Collect op_node
  std::vector<OpNode*> op_node_list;
  op_graph.ForEachNode([&](OpNode* op_node) {
    // TODO: support local op
    bool is_local_conf = false;
    {
      const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view();
      const auto& iter = op_name2is_local.find(op_node->op().op_name());
      if (iter != op_name2is_local.end()) { is_local_conf = iter->second; }
    }
    // Local (non-global) operators are not supported by auto-parallel yet.
    CHECK(is_local_conf == false) << "Haven't deal with local operators.";
    op_node_list.push_back(op_node);
  });
  // Decide the order to visit the op (sorted by op name for determinism)
  std::vector<int32_t> order;
  auto CompareOpName = [&](OpNode* a, OpNode* b) {
    return a->op().op_name().compare(b->op().op_name()) > 0;
  };
  auto_parallel::DecideOrder(op_node_list, order, CompareOpName);
  std::vector<int32_t> output_order;
  // Create sbp nodes
  for (int32_t i = 0; i < op_node_list.size(); i++) {
    OpNode* op_node = op_node_list[order[i]];
    // Generate sbp node in cost model and link it with corresponding op node
    SbpNode* sbp_node = sbp_graph_.GenerateNode();
    // Mapping from sbp_node to op_node
    sbp_node->op_node_ = op_node;
    // TODO: SetOpNode()
    op_name2sbp_node_[op_node->op().op_name()] = sbp_node;
  }
  // Create sbp edges (second pass, once every node exists)
  for (int32_t i = 0; i < op_node_list.size(); i++) {
    OpNode* op_node = op_node_list[order[i]];
    // Get corresponding sbp node
    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
    std::vector<OpNode*> output_node_list;
    for (const auto* op_edge : op_node->out_edges()) {
      output_node_list.push_back(op_edge->dst_node());
    }
    // Deterministic edge creation order as well.
    auto_parallel::DecideOrder(output_node_list, output_order, CompareOpName);
    for (int32_t j : output_order) {
      const auto& end_node_name = output_node_list[j]->op().op_name();
      // Generate sbp edge in cost model
      sbp_node->PointTo(op_name2sbp_node_[end_node_name]);
    }
  }
  return Maybe<void>::Ok();
}
// For every operator, compute the list of valid NdSbp signatures (given the
// logical blob shapes of its inputs/outputs and its parallel description)
// and store it on the matching sbp node.
Maybe<void> SbpConstructor::FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job) {
  // TODO: use user sbp signature in JobParallelViewConf
  // const JobParallelViewConf& job_parallel_view_conf(job.job_parallel_view_conf());
  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
    // Cache blob descriptions for every input and output blob name.
    HashMap<std::string, const BlobDesc*> ibn2blob_desc;
    auto FindShape4Blobs = [&](const PbRpf<std::string>& bns) -> Maybe<void> {
      for (const std::string& ibn : bns) {
        const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
        const BlobDesc* logical_blob_desc = &op_node->LogicalBlobDesc4Lbi(lbi);
        ibn2blob_desc.emplace(ibn, logical_blob_desc);
      }
      return Maybe<void>::Ok();
    };
    JUST(FindShape4Blobs(op_node->op().input_bns()));
    JUST(FindShape4Blobs(op_node->op().output_bns()));
    // Get logical blob description
    auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {
      const auto& it = ibn2blob_desc.find(ibn);
      if (it == ibn2blob_desc.end()) {
        return Error::InvalidValueError()
               << "Cannot find corresponding blob description for input_blob_name : " + ibn + " in "
                      + op_node->op().op_name();
      }
      return *(it->second);
    };
    // Get all valid sbp_signatures
    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
    JUST(op_node->op().GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, op_node->parallel_desc(),
                                                  &sbp_node->sbp_sig_list_,
                                                  /*check_output=*/true));
    sbp_node->InitializeSbp();
    return Maybe<void>::Ok();
  }));
  return Maybe<void>::Ok();
}
// Seed the search with the strategy already present on the original op graph:
// whenever an operator's inferred signature matches one of the node's
// candidates, select that candidate as the node's final signature.
Maybe<void> SbpConstructor::StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job) {
  for (auto* sbp_node : sbp_graph_.node_list_) {
    // Proxy nodes created by the sbp collector carry no op node; skip them.
    if (!sbp_node->op_node_) { continue; }
    for (int32_t candidate_id = 0; candidate_id < sbp_node->sbp_sig_list_.size(); ++candidate_id) {
      if (*JUST(sbp_node->op_node_->op().nd_sbp_signature())
          == sbp_node->sbp_sig_list_[candidate_id]) {
        sbp_node->final_sbp_sig_id_ = candidate_id;
        break;
      }
    }
  }
  return Maybe<void>::Ok();
}
// Fill cost_[sbp_id] on every sbp node: per-signature computation cost,
// scaled by cost_ratio_ and the fastest input/output time-shape element
// count. An "infinite" complexity (above the valid-copy-cost ceiling) is
// stored unscaled so it stays an effective prohibition marker.
Maybe<void> SbpConstructor::InitComputationCost(const OpGraph& op_graph) {
  // Compute computation cost for sbp nodes
  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
    // get corresponding sbp node producer
    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
    // get parallel description. Number of devices.
    const ParallelDesc& parallel_desc = op_node->parallel_desc();
    CHECK_EQ_OR_RETURN(sbp_node->cost_.size(), sbp_node->sbp_sig_list_.size());
    auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& {
      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);
      return op_node->LogicalBlobDesc4Lbi(lbi);
    };
    for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {
      double comp_cost = JUST(op_node->op().GetComputeComplexity(
          &sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc));
      if (comp_cost > GetValidMaxCopyCost()) {
        // Keep prohibitive costs as-is so they remain above the ceiling.
        sbp_node->cost_[sbp_id] = comp_cost;
      } else {
        sbp_node->cost_[sbp_id] =
            cost_ratio_ * comp_cost
            * JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt();
      }
    }
    return Maybe<void>::Ok();
  }));
  return Maybe<void>::Ok();
}
// Allocate and initialize the copy-cost tables on every non-proxy sbp edge,
// then let each consumer node compute its actual copy costs (including
// wait-time-only edges).
Maybe<void> SbpConstructor::InitCopyCost(const OpGraph& op_graph) {
  // Compute copy cost for sbp edges
  op_graph.ForEachNode([&](OpNode* op_node) {
    // get corresponding sbp node consumer
    SbpNode* sbp_node_consumer = op_name2sbp_node_[op_node->op().op_name()];
    // Initialize copy cost between two nodes
    for (auto* sbp_edge : sbp_node_consumer->edges_in_) {
      // producer sbp node
      const auto* sbp_node_producer = sbp_edge->start_node_;
      // skip it if proxy (proxy edges were already sized by the collector)
      if (!sbp_node_producer->op_node_) { continue; }
      // cost_ is indexed [producer signature][consumer signature]
      sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());
      int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();
      // look through sbp signature in producer
      for (int32_t i = 0; i < sbp_node_producer->sbp_sig_list_.size(); ++i) {
        sbp_edge->cost_[i].resize(consumer_sbp_size, 0);
      }
    }
    // Find all those cases with wait time
    // Do not skip edges carrying no lbi
    sbp_node_consumer->InitializeCopyCost(use_sbp_collector_);
  });
  return Maybe<void>::Ok();
}
// Run the trunk algorithm: compute each node's layer number (respecting
// mutable control dependencies of variable ops) and accumulate extra cost
// along the trunk. Requires computation costs to be initialized first.
Maybe<void> SbpConstructor::ApplyTrunkAlgo() {
  auto OpNode2MutableOpCtrlDeps = JUST(GetMutableOpCtrlDeps(*op_graph_));
  // Compute layer number for each node
  int32_t max_min_layer = sbp_graph_.ComputeLayer(op_name2sbp_node_, *OpNode2MutableOpCtrlDeps);
  // Accumulate cost on the trunk after initializing computation cost
  sbp_graph_.FindTrunk(max_min_layer, op_name2sbp_node_);
  return Maybe<void>::Ok();
}
// Load logical blob ids onto sbp edges
void
SbpConstructor
::
LoadLbi2SbpEdge
(
const
OpGraph
&
op_graph
)
{
// Load logical blobs onto sbp edges
for
(
auto
*
sbp_node_consumer
:
sbp_graph_
.
node_list_
)
{
auto
*
op_node
=
sbp_node_consumer
->
op_node_
;
// Loading logical blobs between two nodes
// look through input blobs
for
(
const
std
::
string
&
ibn
:
op_node
->
op
().
input_bns
())
{
// Each input blob has one source op node.
OpNode
*
producer
=
op_node
->
MutSrcNode4Ibn
(
ibn
);
// producer sbp node
const
auto
*
sbp_node_producer
=
op_name2sbp_node_
[
producer
->
op
().
op_name
()];
// TODO: recode this
auto
*
edge_found
=
sbp_node_consumer
->
FindEdgeWithNode
(
sbp_node_producer
);
CHECK
(
edge_found
!=
NULL
)
<<
"SbpEdge not found while loading!"
<<
std
::
endl
;
// Add copy cost for each blob
const
LogicalBlobId
&
lbi
=
op_node
->
op
().
BnInOp2Lbi
(
ibn
);
edge_found
->
LoadLbi
(
lbi
);
}
};
}
// Sanity check: rebuild an op graph from a copy of the job and verify that
// every operator's re-inferred NdSbp matches what auto-parallel wrote into
// the job's parallel-view conf. Returns an error naming the first mismatch.
Maybe<void> SbpConstructor::CheckSbpAgreement(const Job& job) {
  Job new_job;
  new_job.CopyFrom(job);
  // Re-running OpGraph construction re-infers every signature from scratch.
  OpGraph op_graph(new_job);
  // Compare sbp in job
  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
    const std::string& op_name = op_node->op().op_name();
    const NdSbpSignature& auto_parallel_sbp =
        NdSbpSignature(job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name));
    const NdSbpSignature& new_sbp = op_node->nd_sbp_signature();
    CHECK_EQ_OR_RETURN(auto_parallel_sbp.bn_in_op2nd_sbp_size(), new_sbp.bn_in_op2nd_sbp_size());
    for (const auto& iter : auto_parallel_sbp.bn_in_op2nd_sbp()) {
      const NdSbp& new_sbp_parallel = new_sbp.bn_in_op2nd_sbp().at(iter.first);
      // NOTE(review): this intentionally shadows the outer NdSbpSignature
      // `auto_parallel_sbp` with the per-blob NdSbp of the same name; the
      // remainder of the loop body refers to the inner one.
      const NdSbp& auto_parallel_sbp = iter.second;
      // According error message, we can find op_type in op_conf.proto with type_id and locate
      // the error op type.
      const std::string& error_mgs =
          "Op: `" + op_name + "`(type_id: " + std::to_string(op_node->op().op_conf().op_type_case())
          + ") changed sbp from " + NdSbpToString(auto_parallel_sbp) + "(AutoParallel) to "
          + NdSbpToString(new_sbp_parallel) + "(OpGraph) with blob_name: `" + iter.first + "`.";
      CHECK_OR_RETURN(new_sbp_parallel == auto_parallel_sbp) << error_mgs;
    }
    return Maybe<void>::Ok();
  }));
  return Maybe<void>::Ok();
}
// For each variable op with more than one consumer, collect control dependencies
// that order its read-only ("naive") consumers before its single in-place
// ("mutable") consumer, then drop the dependencies that are already implied by
// existing data/ctrl reachability.
// Returns: map from the mutable consumer node to the names of ops it must wait for.
Maybe<HashMap<const OpNode*, HashSet<std::string>>> SbpConstructor::GetMutableOpCtrlDeps(
    const OpGraph& op_graph) {
  // True iff `op` consumes `lbi` through an input whose blob modifier is mutable,
  // i.e. the op may overwrite the blob in place.
  auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
    for (const std::string& bn : op.input_bns()) {
      if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }
    }
    return false;
  };
  const auto& IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
  HashMap<const OpNode*, HashSet<std::string>> op_node2ctrl_in_op_names;
  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
    // Only variable ops are considered.
    if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }
    // With at most one consumer there is nothing to order.
    if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }
    const Operator& variable_op = op_node->op();
    const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());
    const OpNode* mutable_consumer = nullptr;
    std::vector<const OperatorConf*> naive_consumers;
    naive_consumers.reserve(op_node->out_edges().size());
    // Partition the consumers into the single mutable one and the naive rest.
    for (OpEdge* edge : op_node->out_edges()) {
      const auto& op_conf = edge->dst_node()->op().op_conf();
      if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {
        // At most one in-place consumer per variable is expected.
        CHECK_OR_RETURN(mutable_consumer == nullptr);
        mutable_consumer = edge->dst_node();
      } else {
        naive_consumers.emplace_back(&op_conf);
      }
    }
    if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }
    // The mutable consumer must run after every naive consumer has read the blob.
    for (const auto* fw_bw_op : naive_consumers) {
      op_node2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());
    }
    return Maybe<void>::Ok();
  }));
  // Filter ctrl edges if all ctrl_in_op_names are reachable
  HashMap<const OpNode*, HashSet<std::string>> filter_op_ctrl_deps;
  for (const auto& pair : op_node2ctrl_in_op_names) {
    const OpNode* op_node = pair.first;
    for (const auto& fw_bw_op_name : pair.second) {
      // Keep only the dependencies not already implied by the graph.
      if (!IsReachable(fw_bw_op_name, op_node->op().op_name())) {
        filter_op_ctrl_deps[op_node].insert(fw_bw_op_name);
      }
    }
  }
  return filter_op_ctrl_deps;
}
// Print the graph with SBP in order.
// Dumps, to stdout, constructor settings followed by every op (in descending
// name order): its computation cost, layer placement, and the final nd-sbp of
// each input/output blob. Debug aid only; no state is modified.
void SbpConstructor::PrintSBPGraphDebugInfo() {
  // sbp constructor information
  std::cout << "cost_ratio_:" << cost_ratio_ << std::endl;
  std::cout << "wait_time_:" << sbp_graph_.wait_time_ << std::endl;
  std::cout << "use_sbp_collector_" << use_sbp_collector_ << std::endl;
  // test debug
  std::cout << "Get Into Print Op Graph" << std::endl;
  // Collect op_node; proxy sbp nodes have no op_node_ and are skipped.
  std::vector<OpNode*> node_list;
  for (const auto& op_name_sbp_node : op_name2sbp_node_) {
    auto* op_node_ = op_name_sbp_node.second->op_node_;
    if (op_node_) { node_list.push_back(op_node_); }
  }
  // test debug
  std::cout << "Deciding order" << std::endl;
  // Decide the order to visit the op (descending op name).
  std::vector<int32_t> order;
  auto_parallel::DecideOrder(node_list, order, [&](OpNode* a, OpNode* b) {
    return a->op().op_name().compare(b->op().op_name()) > 0;
  });
  std::vector<int32_t> str_order;
  // test debug
  std::cout << "Finish deciding order" << std::endl;
  for (int32_t i = 0; i < node_list.size(); i++) {
    OpNode* op_node = node_list[order[i]];
    std::cout << op_node->op().op_name() << " (^_^):" << std::endl;
    // get corresponding sbp node
    const auto& it = op_name2sbp_node_.find(op_node->op().op_name());
    // Print debug information for sbp graph
    CHECK(it != op_name2sbp_node_.end());
    const SbpNode* sbp_node = it->second;
    // Cost of the signature finally chosen for this node.
    std::cout << "Computation Cost: " << sbp_node->cost_[sbp_node->final_sbp_sig_id_];
    std::cout << ", Min Layer: " << sbp_node->min_layer_ << ", Max Layer: " << sbp_node->max_layer_
              << ", Tributary Layer: " << sbp_node->tributary_layer_
              << ", in trunk: " << sbp_node->on_trunk_
              << ", Remain Cost: " << sbp_node->acc_trunk_cost_ << std::endl;
    // Sort before printing (descending blob name).
    const auto& op_input_bns = op_node->op().input_bns();
    auto CompareString = [](const std::string& a, const std::string& b) {
      return a.compare(b) > 0;
    };
    auto_parallel::DecideOrder(op_input_bns, str_order, CompareString);
    const NdSbpSignature& sbp_signature = sbp_node->FinalSbpSignature();
    // Print out SBP information for input operator
    for (int32_t j : str_order) {
      const auto& ibn = op_input_bns[j];
      const auto& producer_node = op_node->SrcNode4Ibn(ibn);
      std::cout << "Pre Op:" << producer_node.op().op_name() << ": " << ibn;
      const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(ibn);
      std::cout << ", " << NdSbpToString(this_sbp_parallel);
      if (RequireSameSbp(op_node, ibn)) { std::cout << ", require same SBP"; }
      // Element count of the consumed logical blob.
      std::cout << ", "
                << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape().elem_cnt();
      std::cout << std::endl;
    }
    // Sort before printing (descending blob name).
    const auto& op_output_bns = op_node->op().output_bns();
    auto_parallel::DecideOrder(op_output_bns, str_order, CompareString);
    // Print out SBP information for output blobs
    for (int32_t j : str_order) {
      const auto& obn = op_output_bns[j];
      std::cout << "Out Op:" << obn;
      const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(obn);
      std::cout << ", " << NdSbpToString(this_sbp_parallel);
      // Element count of the produced logical blob.
      std::cout << ", "
                << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape().elem_cnt();
      std::cout << std::endl;
    }
    std::cout << std::endl;
  }
}
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/sbp_constructor.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/job/global_for.h"
namespace
oneflow
{
class
OpGraph
;
class
Job
;
namespace
auto_parallel
{
// A constructor which will assemble the sbp_graph with the information from oneflow.
// SbpGraph contains the algorithms for elimination and search which is mainly for the strategy
// itself. Constructor mainly deal with the assemblage of each node, edge and the cost computation,
// activation of functions.
class SbpConstructor final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SbpConstructor);
  SbpConstructor() = delete;
  // Reads the auto-parallel settings from the job configuration, sets the wait
  // time on the sbp graph, and eagerly builds the graph via Init().
  SbpConstructor(const OpGraph& op_graph, Job* job)
      : cost_ratio_(job->job_conf().auto_parallel_computation_cost_ratio()),
        enable_trunk_algo_(job->job_conf().enable_auto_parallel_trunk_algo()),
        // The sbp collector is used only when group boxing by dst parallel is not
        // disabled for the session AND the job enables the collector.
        use_sbp_collector_(!Singleton<ResourceDesc, ForSession>::Get()
                                ->resource()
                                .disable_group_boxing_by_dst_parallel()
                           && job->job_conf().enable_auto_parallel_sbp_collector()),
        op_graph_(&op_graph) {
    sbp_graph_.SetWaitTime(job->job_conf().auto_parallel_wait_time());
    CHECK_JUST(Init(op_graph, job));
  }
  ~SbpConstructor() = default;
  // Assemble the sbp graph from the op graph and job.
  Maybe<void> Init(const OpGraph& op_graph, Job* job);
  // Run the search for the lowest-cost sbp signature assignment.
  Maybe<void> FindBestSbpSignature();
  // Write the decided nd-sbp signatures back into the job.
  Maybe<void> DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job);
  // Re-build OpGraph and check all sbp is same between op_graph and job
  Maybe<void> CheckSbpAgreement(const Job& job);
  // Print the graph with SBP in order
  void PrintSBPGraphDebugInfo();

 private:
  Maybe<void> InitSbpGraph(const OpGraph& op_graph, const Job& job);
  Maybe<void> GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job);
  Maybe<void> FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job);
  Maybe<void> StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job);
  Maybe<void> InitComputationCost(const OpGraph& op_graph);
  Maybe<void> InitCopyCost(const OpGraph& op_graph);
  Maybe<void> ApplyTrunkAlgo();
  // Collect ctrl dependencies for mutable consumers of variable ops.
  Maybe<HashMap<const OpNode*, HashSet<std::string>>> GetMutableOpCtrlDeps(const OpGraph& op_graph);
  // Load logical blob ids onto sbp edges
  void LoadLbi2SbpEdge(const OpGraph& op_graph);

  // Weight of computation cost relative to copy cost in the objective.
  double cost_ratio_;
  // Whether the trunk (main-chain) algorithm is applied.
  bool enable_trunk_algo_;
  // Whether proxy nodes (sbp collector) are inserted for multi-consumer blobs.
  bool use_sbp_collector_;
  SbpGraph sbp_graph_;
  // Not owned; must outlive this constructor.
  const OpGraph* op_graph_;
  // Maps op name to its sbp node in sbp_graph_ (nodes owned by sbp_graph_).
  HashMap<std::string, SbpNode*> op_name2sbp_node_;
};
}
// namespace auto_parallel
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
oneflow/core/auto_parallel/sbp_edge.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <assert.h>
#include <algorithm>
#include <unordered_set>
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/graph/op_graph.h"
namespace
oneflow
{
namespace
auto_parallel
{
// function in cpp. Should be put in one file due to use of template
// Otherwise we will need to declare specific template at the end of cpp file.
// Constructor for a type-3 edge: start_node -> mid_node -> end_node.
// The two underlying edges are taken over and stored in edge_list_.
SbpEdge::SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,
                 SbpEdge* second_edge)
    : start_node_(start_node), mid_node_(mid_node), end_node_(end_node) {
  edge_list_.reserve(2);
  edge_list_.push_back(first_edge);
  edge_list_.push_back(second_edge);
}
// Destructor: the edge owns its middle node (type 3) and all contained edges.
SbpEdge::~SbpEdge() {
  // Deleting a null pointer is a no-op, so no explicit check is required.
  delete mid_node_;
  for (SbpEdge* owned_edge : edge_list_) { delete owned_edge; }
}
// Update copy cost for type 2 and 3 edges.
// Type 3 (with mid_node_): for every (start, end) signature pair, pick the middle
// node signature that minimizes mid-node cost + both sub-edge costs, recording the
// choice in mid_node_sbp_sig_.
// Type 2 (parallel edges): cost is the sum over all contained edges, each indexed
// in its own direction.
void SbpEdge::SummarizeCost() {
  if (mid_node_) {
    cost_.resize(start_node_->cost_.size());
    mid_node_sbp_sig_.resize(start_node_->cost_.size());
    int32_t end_node_sbp_size = end_node_->cost_.size();
    int32_t mid_node_sbp_size = mid_node_->cost_.size();
    for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
      cost_[sbp_start].resize(end_node_sbp_size);
      mid_node_sbp_sig_[sbp_start].resize(end_node_sbp_size);
      for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {
        // Minimize over all middle-node signatures.
        for (int32_t sbp_mid = 0; sbp_mid < mid_node_sbp_size; sbp_mid++) {
          // Add middle node cost
          double temp_cost = mid_node_->cost_[sbp_mid];
          // Add first edge cost; a contained edge may be stored in reverse
          // direction, so index its cost matrix accordingly.
          if (edge_list_[0]->start_node_ == start_node_) {
            temp_cost += edge_list_[0]->cost_[sbp_start][sbp_mid];
          } else {
            temp_cost += edge_list_[0]->cost_[sbp_mid][sbp_start];
          }
          // Add second edge cost
          if (edge_list_[1]->end_node_ == end_node_) {
            temp_cost += edge_list_[1]->cost_[sbp_mid][sbp_end];
          } else {
            temp_cost += edge_list_[1]->cost_[sbp_end][sbp_mid];
          }
          // Compare and look for the minimum cost; sbp_mid == 0 initializes the entry.
          if (sbp_mid == 0) {
            cost_[sbp_start][sbp_end] = temp_cost;
            mid_node_sbp_sig_[sbp_start][sbp_end] = sbp_mid;
          } else if (temp_cost < cost_[sbp_start][sbp_end]) {
            cost_[sbp_start][sbp_end] = temp_cost;
            mid_node_sbp_sig_[sbp_start][sbp_end] = sbp_mid;
          }
        }
      }
    }
  } else {
    cost_.resize(start_node_->cost_.size());
    int32_t end_node_sbp_size = end_node_->cost_.size();
    for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
      cost_[sbp_start].resize(end_node_sbp_size);
      for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {
        cost_[sbp_start][sbp_end] = 0;
        // Sum the cost of every parallel edge, honoring each edge's direction.
        for (int32_t edge_num = 0; edge_num < edge_list_.size(); edge_num++) {
          if (edge_list_[edge_num]->start_node_ == start_node_) {
            cost_[sbp_start][sbp_end] += edge_list_[edge_num]->cost_[sbp_start][sbp_end];
          } else {
            cost_[sbp_start][sbp_end] += edge_list_[edge_num]->cost_[sbp_end][sbp_start];
          }
        }
      }
    }
  }
}
// Duplicate Cost. Designed for merging two nodes.
// After two nodes are merged, each merged signature id maps to a pair of child
// signature ids; this re-expands the cost matrix (and mid_node_sbp_sig_ for type-3
// edges) along the merged side so it is indexed by the merged signature ids.
// merged_node_is_start_node: which side of this edge the merged node sits on.
// duplicating_first_node: whether to take .first or .second of each id pair.
void SbpEdge::DuplicateCost(
    bool merged_node_is_start_node, bool duplicating_first_node,
    const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2children_sig_id) {
  const int32_t num_sig = merged_sig_id2children_sig_id.size();
  std::vector<std::vector<double>> temp_cost;
  std::vector<std::vector<int32_t>> temp_mid_node_sbp_sig;
  if (merged_node_is_start_node) {
    // Expand along the row (start-node) dimension: copy whole rows.
    temp_cost.resize(num_sig);
    if (mid_node_) { temp_mid_node_sbp_sig.resize(num_sig); }
    for (int32_t i = 0; i < num_sig; i++) {
      const int32_t sig_idx = duplicating_first_node ? merged_sig_id2children_sig_id[i].first
                                                     : merged_sig_id2children_sig_id[i].second;
      temp_cost[i] = cost_[sig_idx];
      if (mid_node_) { temp_mid_node_sbp_sig[i] = mid_node_sbp_sig_[sig_idx]; }
    }
  } else {
    // Expand along the column (end-node) dimension: copy element by element.
    const int32_t num_start_sig = cost_.size();
    temp_cost.resize(num_start_sig);
    if (mid_node_) { temp_mid_node_sbp_sig.resize(num_start_sig); }
    for (int32_t i = 0; i < num_start_sig; i++) {
      temp_cost[i].resize(num_sig);
      if (mid_node_) { temp_mid_node_sbp_sig[i].resize(num_sig); }
      for (int32_t j = 0; j < num_sig; j++) {
        const int32_t sig_idx = duplicating_first_node ? merged_sig_id2children_sig_id[j].first
                                                       : merged_sig_id2children_sig_id[j].second;
        temp_cost[i][j] = cost_[i][sig_idx];
        if (mid_node_) { temp_mid_node_sbp_sig[i][j] = mid_node_sbp_sig_[i][sig_idx]; }
      }
    }
  }
  // Install the expanded matrices.
  cost_ = temp_cost;
  if (mid_node_) { mid_node_sbp_sig_ = temp_mid_node_sbp_sig; }
}
void
SbpEdge
::
FinalizeSbp
()
{
// Finalize Sbp for mid_node_
if
(
mid_node_
)
{
mid_node_
->
final_sbp_sig_id_
=
mid_node_sbp_sig_
[
start_node_
->
final_sbp_sig_id_
][
end_node_
->
final_sbp_sig_id_
];
mid_node_
->
FinalizeSbp
();
}
for
(
const
auto
&
this_edge
:
edge_list_
)
{
this_edge
->
FinalizeSbp
();
}
}
// Use Greedy Strategy to pick the sbp signature pair (start, end) with minimum
// cost for this edge, holding the rest of the graph fixed.
// Requires an initial strategy (final_sbp_sig_id_ set on both ends).
// Returns the cost change (new minimum minus original cost), <= 0.
// NOTE: temporarily mutates final_sbp_sig_id_ on both end nodes while probing;
// the chosen pair is written back before returning.
double SbpEdge::GreedyStrategy() {
  // Sbp combination of the minimum cost
  int32_t min_sbp_start = start_node_->final_sbp_sig_id_,
          min_sbp_end = end_node_->final_sbp_sig_id_;
  // An unordered_map to evaluate cost between two edge nodes and other nodes.
  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id = {{start_node_->node_list_id_, 0},
                                                              {end_node_->node_list_id_, 1}};
  // pre-compute and store the current cost between end_node_ and outside.
  std::vector<double> end_node_out_cost(end_node_->cost_.size());
  for (int32_t sbp_end = 0; sbp_end < cost_[0].size(); sbp_end++) {
    end_node_->final_sbp_sig_id_ = sbp_end;
    end_node_out_cost[sbp_end] = end_node_->EvalOutNbhCost(node_list_id2nbh_id);
  }
  // pre-compute and store the current cost between start_node_ and outside.
  std::vector<double> start_node_out_cost(start_node_->cost_.size());
  for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
    start_node_->final_sbp_sig_id_ = sbp_start;
    start_node_out_cost[sbp_start] = start_node_->EvalOutNbhCost(node_list_id2nbh_id);
  }
  // Current Cost, Minimum Cost, Cost with original sbp
  double curr_cost = 0.0;
  double min_cost = start_node_out_cost[min_sbp_start] + end_node_out_cost[min_sbp_end]
                    + cost_[min_sbp_start][min_sbp_end];
  double original_cost = min_cost;
  // Exhaustively scan all (start, end) signature pairs.
  for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
    for (int32_t sbp_end = 0; sbp_end < cost_[0].size(); sbp_end++) {
      // compute Current Cost for Neighborhood of edge
      end_node_->final_sbp_sig_id_ = sbp_end;
      curr_cost =
          start_node_out_cost[sbp_start] + end_node_out_cost[sbp_end] + cost_[sbp_start][sbp_end];
      // Find the minimum current cost
      if (curr_cost < min_cost) {
        min_cost = curr_cost;
        min_sbp_start = sbp_start;
        min_sbp_end = sbp_end;
      }
    }
  }
  // Commit the best pair found.
  start_node_->final_sbp_sig_id_ = min_sbp_start;
  end_node_->final_sbp_sig_id_ = min_sbp_end;
  return min_cost - original_cost;
}
// Get the minimum element in Cost.
// The result is cached in min_cost_ after the first computation; a negative
// cached value means "not computed yet".
double SbpEdge::GetMinCost() {
  // Return the cached value when available.
  if (min_cost_ >= 0) { return min_cost_; }
  CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl;
  // Scan row by row, keeping the running minimum.
  double running_min = *std::min_element(cost_[0].begin(), cost_[0].end());
  for (int32_t row = 1; row < cost_.size(); row++) {
    running_min = std::min(running_min, *std::min_element(cost_[row].begin(), cost_[row].end()));
  }
  min_cost_ = running_min;
  return min_cost_;
}
// Get the maximum element in Cost that is still a valid copy cost
// (entries at or above GetValidMaxCopyCost() are treated as infeasible).
// Returns -1.0 when no entry is valid.
double SbpEdge::GetMaxCost() const {
  // used the stored value if pre-computed.
  // if (max_cost >= 0) return max_cost;
  CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl;
  double max_valid_cost = -1.0;
  for (const auto& cost_row : cost_) {
    for (double curr_cost : cost_row) {
      if (curr_cost < GetValidMaxCopyCost() && curr_cost > max_valid_cost) {
        max_valid_cost = curr_cost;
      }
    }
  }
  return max_valid_cost;
}
// Assemble copy cost.
// Fills cost_[producer_sig][consumer_sig] with the cost of transferring the blob
// named `ibn` (on the consumer side) between every pair of producer/consumer sbp
// signatures. Valid costs are scaled by the producer's time-shape element count
// and accumulated; infeasible costs overwrite the entry.
// When use_sbp_collector is set, edges that do not carry this blob are skipped.
void SbpEdge::InitializeCopyCost(const std::string& ibn, bool use_sbp_collector) {
  // In this part, we assemble the cost from nodes to nodes.
  if (start_node_->op_node_ && end_node_->op_node_) {
    OpNode* consumer = end_node_->op_node_;
    // Add copy cost for each blob
    const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn);
    // Check whether lbi is transferred by this edge
    if (use_sbp_collector && !SearchLbi(lbi)) { return; }
    OpNode* producer = start_node_->op_node_;
    // Output blob name on the producer side for this lbi. The original code
    // computed this twice (producer_lbn and obn were identical); compute once.
    const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));
    const ParallelDesc& producer_parallel_desc =
        *CHECK_JUST(producer->op().GetParallelDesc4BnInOp(obn));
    const ParallelDesc& consumer_parallel_desc =
        *CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn));
    // Need to be careful, the logical blob description should be independent to current
    // SbpParallel. Use producer or op_node?
    const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);
    // If we are deciding whether we need the wait time, then make require_same_sbp true.
    // B->S cause cudaEventSynchronize in current implementation.
    bool require_same_sbp = RequireSameSbp(consumer, ibn);
    int32_t consumer_sbp_size = end_node_->sbp_sig_list_.size();
    LazyMode::Guard enable_lazy_mode(true);
    // look through sbp signature in producer
    for (int32_t sbp_id_producer = 0; sbp_id_producer < start_node_->sbp_sig_list_.size();
         sbp_id_producer++) {
      // get sbp parallel for a logical blob in producer
      const auto& producer_sbp_bn_in_op2sbp_parallel =
          start_node_->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();
      const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);
      // look through sbp signature in consumer
      for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
        // get sbp parallel for a logical blob in consumer
        const auto& consumer_sbp_bn_in_op2sbp_parallel =
            end_node_->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();
        const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);
        // compute copy cost for a specific logical blob
        double curr_edge_cost = CHECK_JUST(ComputeCopyCostWithMiddleNodes(
            sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
            consumer_parallel_desc, require_same_sbp));
        if (curr_edge_cost < GetValidMaxCopyCost()) {
          // Valid cost: scale by how many times the op runs and accumulate.
          cost_[sbp_id_producer][sbp_id_consumer] +=
              CHECK_JUST(producer->op().GetOpTimeShape())->elem_cnt() * curr_edge_cost;
        } else {
          // Infeasible: keep the sentinel cost as-is.
          cost_[sbp_id_producer][sbp_id_consumer] = curr_edge_cost;
        }
      }
    }
  }
}
// Set the cut ratio
double
SbpEdge
::
GetCutRatio
()
const
{
int32_t
num
=
0
;
for
(
int32_t
i
=
0
;
i
<
cost_
.
size
();
i
++
)
{
for
(
int32_t
j
=
0
;
j
<
cost_
[
i
].
size
();
j
++
)
{
if
(
cost_
[
i
][
j
]
<
GetValidMaxCopyCost
())
{
num
++
;
}
}
}
return
double
(
num
)
/
double
(
cost_
.
size
()
*
cost_
[
0
].
size
());
}
// find the cut ratio
// (#c>GetValidMaxCopyCost() in Cost)/(#c in Cost)
// The ratio is lifted to 1 for edges whose cost matrix is too dense/large, to
// filter out improper couples and avoid unlimited merging.
double SbpEdge::FindCutRatio(int32_t threshold) const {
  double cut_ratio = GetCutRatio();
  // lift the cut ratio to 1 to filter out some improper couples to avoid unlimited merging
  const double rows = cost_.size();
  const double cols = cost_[0].size();
  const double valid_count = cut_ratio * rows * cols;
  cut_ratio += 0.16 * (rows + cols) / double(threshold);
  const bool keep_ratio = valid_count <= rows * 2 || valid_count <= cols * 2
                          || (valid_count <= threshold && cut_ratio < 0.51);
  return keep_ratio ? cut_ratio : 1.0;
}
// load a logical blob: record that this edge carries/transfers the blob.
void SbpEdge::LoadLbi(const LogicalBlobId& lbi) { carry_lbis_.emplace(lbi); }
// check the existence of a logical blob on this edge.
bool SbpEdge::SearchLbi(const LogicalBlobId& lbi) const { return carry_lbis_.count(lbi) != 0; }
// unload a logical blob; warns on stdout if the blob was never loaded.
void SbpEdge::UnloadLbi(const LogicalBlobId& lbi) {
  const size_t num_erased = carry_lbis_.erase(lbi);
  if (num_erased == 0) { std::cout << "Unload an empty lbi!" << std::endl; }
}
// Not carrying any blob.
bool SbpEdge::EmptyLbi() const { return carry_lbis_.size() == 0; }
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/sbp_edge.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
#include <assert.h>
#include <algorithm>
#include <unordered_set>
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/graph/op_graph.h"
namespace
oneflow
{
namespace
auto_parallel
{
// An edge structure to deal with the SBP strategy.
// Please see SbpGraph for the whole algorithm and introduction.
class SbpEdge final {
  /* There are 3 types of edges:
   * 1. start_node_ -> end_node_
   *      Nothing special
   * 2. Multiple start_node_ -> end_node_
   *      edge_list_ will store all the edges which goes from start_node_ to end_node_
   * 3. start_node_ -> mid_node_ -> end_node_
   *      It will pass by a middle node.
   */
 public:
  // Constructor for type 1 & 2
  SbpEdge(SbpNode* start_node, SbpNode* end_node) : start_node_(start_node), end_node_(end_node) {
    mid_node_ = nullptr;
  }
  // Constructor for type 3; takes ownership of first_edge and second_edge.
  SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,
          SbpEdge* second_edge);
  // Destructor: owns mid_node_ and every edge in edge_list_.
  ~SbpEdge();
  OF_DISALLOW_COPY_AND_MOVE(SbpEdge);
  // Identity comparison: two edges are equal only if they are the same object.
  bool operator==(const SbpEdge& other) { return this == &other; }
  // Update copy cost for type 2 and 3
  void SummarizeCost();
  // Duplicate Cost. Designed for merging two nodes.
  void DuplicateCost(bool merged_node_is_start_node, bool duplicating_first_node,
                     const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2children_sig_id);
  // Determine Final SbpSignature for attachment of this edge
  void FinalizeSbp();
  // Use Greedy Strategy to pick the sbp signature with minimum cost for this
  // edge. You should have an initial strategy before running this. And the
  // graph should be fully eliminated.
  double GreedyStrategy();
  // load a logical blob
  void LoadLbi(const LogicalBlobId& lbi);
  // check the existence of a logical blob
  bool SearchLbi(const LogicalBlobId& lbi) const;
  // unload a logical blob
  void UnloadLbi(const LogicalBlobId& lbi);
  // Not carrying any blob
  bool EmptyLbi() const;
  // Get the minimum element in Cost (cached in min_cost_ after first call)
  double GetMinCost();
  // Get the maximum element in Cost
  double GetMaxCost() const;
  // Assemble copy cost
  void InitializeCopyCost(const std::string& ibn, bool use_sbp_collector);
  // find the cut ratio
  // (#c>GetValidMaxCopyCost() in Cost)/(#c in Cost)
  // But we would lift the cut ratio to 1 to filter out some improper couples
  double FindCutRatio(int32_t threshold) const;
  // Get the cut ratio
  double GetCutRatio() const;

 private:
  friend class SbpNode;
  friend class SbpGraph;
  friend class SbpCollector;
  friend class SbpConstructor;
  // The edge point from start_node_ to end_node_
  // It will have a middle node if and only if type 3
  SbpNode *start_node_, *mid_node_, *end_node_;
  // Cost[sbp_i][sbp_j] is the total cost from start_node_ with sbp_i to end_node_
  // with sbp_j
  std::vector<std::vector<double>> cost_;
  // SbpSignature for mid_node_ with corresponding Cost if type 3, empty otherwise
  std::vector<std::vector<int32_t>> mid_node_sbp_sig_;
  // Contained edge list:
  //   empty if type 1,
  //   Parallel edges if type 2,
  //   succeed edges if type 3
  // the edge list might have reverse direction:
  //   example 1: type 3 edge_list_ contain two edges:
  //              mid_node_ -> start_node_, mid_node_ -> end_node_;
  //   example 2: type 2 edge_list_ contain three edges:
  //              start_node_ -> end_node_, end_node_ -> start_node_, start_node_ -> end_node_;
  std::vector<SbpEdge*> edge_list_;
  // Time waiting for other gpus. pthread_cond_wait
  double wait_time_ = -1.0;
  // a set of ids of logical blobs carried/transferred on this sbp edge
  std::unordered_set<LogicalBlobId> carry_lbis_;
  // Minimum and maximum cost would not be changed by eliminations, which will generate new edges.
  // Also would not be changed by node merging, which will only perform cost copy for the expanding
  // dimensions.
  // Minimum cost in the 2D array Cost.
  // Would be initialized after GetMinCost();
  // Only used in the final graph. A negative value means "not computed yet".
  double min_cost_ = -1.0;
};
}
// namespace auto_parallel
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
oneflow/core/auto_parallel/sbp_graph.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include <unordered_map>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
namespace
oneflow
{
namespace
auto_parallel
{
// function in cpp. Should be put in one file due to use of template
// Otherwise we will need to declare specific template at the end of cpp file.
namespace {
// Minimum node count for merging; presumably graphs smaller than this skip the
// merging step — the only usage is outside this view, TODO confirm.
// `constexpr` replaces the redundant `static const` inside an anonymous namespace.
constexpr int32_t kMinNodeInGraphForMerging = 4;
}  // namespace
// Generate a node, append it to node_list_, and record its index in the list.
// The graph owns the node; it is released in the destructor.
SbpNode* SbpGraph::GenerateNode() {
  SbpNode* new_node = new SbpNode();
  node_list_.push_back(new_node);
  new_node->node_list_id_ = node_list_.size() - 1;
  return new_node;
}
// Detach a node from node_list_ via swap-with-last removal.
// A negative node_list_id_ marks a node that is already detached.
void SbpGraph::RemoveFromNodeList(SbpNode* this_node) {
  const int32_t removed_id = this_node->node_list_id_;
  if (removed_id < 0) { return; }
  // The last node takes over the vacated slot, so its stored index must follow.
  node_list_.back()->node_list_id_ = removed_id;
  RemoveFrom<SbpNode*>(node_list_, removed_id);
  this_node->node_list_id_ = -1;
}
// Destructor: release every node created by GenerateNode().
SbpGraph::~SbpGraph() {
  for (SbpNode* owned_node : node_list_) { delete owned_node; }
  node_list_.clear();
}
// Assign a random sbp signature to every node (used to seed the search).
void SbpGraph::RandomSbpSignature(bool use_sbp_collector) const {
  for (const auto& curr_node : node_list_) {
    // A node with an empty sbp_sig_list_ must be a proxy; its candidates live in
    // parallel_candidates_ instead.
    const size_t num_candidates = curr_node->sbp_sig_list_.size() > 0
                                      ? curr_node->sbp_sig_list_.size()
                                      : curr_node->parallel_candidates_.size();
    curr_node->final_sbp_sig_id_ = rand() % num_candidates;
  }
}
void
SbpGraph
::
SetDefaultSbpSig
()
const
{
for
(
const
auto
&
this_node
:
node_list_
)
{
this_node
->
final_sbp_sig_id_
=
0
;
}
};
// Overall cost of the whole graph under the current strategy: the computation
// cost of every node's chosen signature plus the copy cost of every outgoing
// edge under the chosen signatures of its two end points.
double SbpGraph::ComputeCost() const {
  // Renamed from `graph_cost_`: a trailing underscore denotes a data member in
  // this codebase, which was misleading for a local variable.
  double graph_cost = 0;
  for (const auto& this_node : node_list_) {
    int32_t this_id = this_node->final_sbp_sig_id_;
    graph_cost += this_node->cost_[this_id];
    for (const auto& edge_out : this_node->edges_out_) {
      graph_cost += edge_out->cost_[this_id][edge_out->end_node_->final_sbp_sig_id_];
    }
  }
  return graph_cost;
}
// Try to eliminate a node of degree 2 by fusing it and its two incident edges
// into a single compound (type-3) edge between its two neighbors.
// Returns the number of eliminations performed (0 when the node's degree != 2).
int32_t SbpGraph::NodeElimination(SbpNode* this_node) {
  if (this_node->edges_in_.size() + this_node->edges_out_.size() == 2) {
    // Collect the two neighbor nodes, in-edge sources first.
    std::vector<SbpNode*> two_nodes;
    for (const auto& one_edge : this_node->edges_in_) two_nodes.emplace_back(one_edge->start_node_);
    for (const auto& one_edge : this_node->edges_out_) two_nodes.emplace_back(one_edge->end_node_);

    // If a node is pointing to itself, could happen when shrink from a circle
    if (two_nodes[0] == two_nodes[1]) {
      int32_t elimination_number = 0;
      // Merge the duplicated edges first, then fold this node into its neighbor.
      if (this_node->edges_out_.empty()) {
        elimination_number += EdgeElimination(two_nodes[0]);
      } else {
        elimination_number += EdgeElimination(this_node);
      }
      elimination_number += ChildElimination(this_node);
      return elimination_number;
    }

    // The two incident edges, in the same order as two_nodes.
    std::vector<SbpEdge*> two_edges(this_node->edges_in_);
    two_edges.insert(two_edges.end(), this_node->edges_out_.begin(), this_node->edges_out_.end());
    int32_t edges_in_size = this_node->edges_in_.size();
    // Build the compound edge neighbor0 -> this_node -> neighbor1, which takes
    // ownership of this_node and the two original edges.
    SbpEdge* e = new SbpEdge(two_nodes[0], this_node, two_nodes[1], two_edges[0], two_edges[1]);
    e->SummarizeCost();
    // check and remove the edge_in with new edge in graph
    for (int32_t i = 0; i < edges_in_size; i++) {
      CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_out_, two_edges[i]);
    }
    // check and remove the edge_out with new edge in graph
    for (int32_t i = edges_in_size; i < 2; i++) {
      CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_in_, two_edges[i]);
    }
    // Let e take control of edge_list_ completely by disconnecting MidNode
    e->mid_node_->edges_out_.clear();
    e->mid_node_->edges_in_.clear();

    // Insert new compound edge into graph
    two_nodes[0]->edges_out_.emplace_back(e);
    two_nodes[1]->edges_in_.emplace_back(e);

    // eliminate the node from graph by swapping with the last element and
    // popping
    RemoveFromNodeList(this_node);

    // successfully eliminate this node
    return 1;
  }
  // can not eliminate this node
  return 0;
}
int32_t
SbpGraph
::
NodeAndEdgeEliminations
()
{
// Total elimination number
int32_t
total_elimination_num
=
0
;
int32_t
elimination_num
=
1
;
// repeat these kinds of elimination until stuck
while
(
elimination_num
>
0
)
{
elimination_num
=
0
;
for
(
int32_t
i
=
node_list_
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
elimination_num
+=
NodeElimination
(
node_list_
[
i
]);
}
for
(
int32_t
i
=
node_list_
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
elimination_num
+=
EdgeElimination
(
node_list_
[
i
]);
}
for
(
int32_t
i
=
node_list_
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
elimination_num
+=
ChildElimination
(
node_list_
[
i
]);
}
if
(
elimination_num
==
0
&&
node_list_
.
size
()
>
2
)
{
elimination_num
+=
PickAndMerge
();
for
(
int32_t
i
=
node_list_
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
elimination_num
+=
EdgeElimination
(
node_list_
[
i
]);
}
}
total_elimination_num
+=
elimination_num
;
}
return
total_elimination_num
;
}
// Merge all parallel edges (same unordered node pair, either direction) incident
// to this_node's out-edges into compound SbpEdges. Returns the number of edges
// folded away.
int32_t SbpGraph::EdgeElimination(SbpNode* this_node) const {
  // Remove all edges with (start_node -> end_node) from edges_in_ of end_node
  auto RemoveFromEdgesIn = [](SbpNode* start_node, SbpNode* end_node) -> void {
    // Iterate backwards: RemoveFrom swaps-and-pops, which is index-safe this way.
    for (int32_t i = end_node->edges_in_.size() - 1; i >= 0; i--) {
      if (start_node == end_node->edges_in_[i]->start_node_) {
        RemoveFrom<SbpEdge*>(end_node->edges_in_, i);
      }
    }
  };
  // Collect into `e` (allocating it on first hit) every edge in
  // start_node->edges_out_ whose end node is end_node, and drop those edges
  // from the list. `if_reverse` controls the orientation of the new compound
  // edge so both directions fold into a single e.
  auto LookForParallelEdge = [](SbpEdge*& e, SbpNode* start_node, SbpNode* end_node,
                                bool if_reverse, int32_t stop_sign) -> int32_t {
    // elimination edges with specific start node and end node in
    // start_node->edges_out_ from index stop sign to the end.
    // start_node->edges_out_[stop_sign] not included and need special treatment
    // after this process.
    int32_t elimination_num = 0;
    for (int32_t j = start_node->edges_out_.size() - 1; j > stop_sign; j--) {
      if (end_node == start_node->edges_out_[j]->end_node_) {
        if (!e) {
          if (if_reverse) {
            e = new SbpEdge(end_node, start_node);
          } else {
            e = new SbpEdge(start_node, end_node);
          }
        }
        // edge elimination: absorb this parallel edge into the compound edge
        e->edge_list_.emplace_back(start_node->edges_out_[j]);
        elimination_num++;
        RemoveFrom<SbpEdge*>(start_node->edges_out_, j);
      }
    }
    return elimination_num;
  };
  int32_t elimination_num = 0;
  // edges_out_ shrinks during the loop, but only at indices > i (stop_sign=i),
  // so indexing forward over the current size stays valid.
  for (int32_t i = 0; i < this_node->edges_out_.size(); i++) {
    SbpEdge* e = nullptr;
    // Find and delete Parallel Edges from edges_out_
    elimination_num += LookForParallelEdge(e, this_node, this_node->edges_out_[i]->end_node_,
                                           /*if_reverse=*/false, i);
    // Also fold edges running in the opposite direction (neighbor -> this_node).
    elimination_num += LookForParallelEdge(e, this_node->edges_out_[i]->end_node_, this_node,
                                           /*if_reverse=*/true, /*stop_sign=*/-1);
    if (e) {
      // Delete Parallel Edges from edges_in_ (both orientations)
      RemoveFromEdgesIn(this_node, e->end_node_);
      RemoveFromEdgesIn(e->end_node_, this_node);
      // Add the compound edge: absorb the remaining edge at index i, then put e
      // in its slot so this_node keeps exactly one edge to that neighbor.
      e->edge_list_.emplace_back(this_node->edges_out_[i]);
      this_node->edges_out_[i] = e;
      e->SummarizeCost();
      e->end_node_->edges_in_.emplace_back(e);
    }
  }
  return elimination_num;
}
// Try to eliminate this_node as a "child" node (the node itself decides whether
// it qualifies via EliminateItselfAsChild). Returns 1 on success, 0 otherwise.
int32_t SbpGraph::ChildElimination(SbpNode* this_node) {
  // Bail out early if the node cannot fold itself into its parent.
  if (!this_node->EliminateItselfAsChild()) { return 0; }
  // Drop the eliminated node from the global node list.
  RemoveFromNodeList(this_node);
  return 1;
}
// Merge two nodes into a single product node and splice it into node_list_.
// Always succeeds; returns 1 (the number of merges performed).
int32_t SbpGraph::NodeMerging(SbpNode* first, SbpNode* second) {
  // The SbpNode merging constructor wires up the combined signatures and edges.
  SbpNode* merged_node = new SbpNode(first, second);
  // Replace the two originals with the merged node in the global list.
  RemoveFromNodeList(first);
  RemoveFromNodeList(second);
  merged_node->node_list_id_ = node_list_.size();
  node_list_.emplace_back(merged_node);
  return 1;
}
void
SbpGraph
::
FinalizeSbp
()
const
{
for
(
const
auto
&
this_node
:
node_list_
)
{
this_node
->
FinalizeSbp
();
}
}
// Iteratively lower the strategy cost with per-node (or per-edge) greedy moves
// until a full pass yields no reduction. Returns the total (negative or zero)
// cost reduction achieved.
double SbpGraph::GreedyStrategy(bool for_node) const {
  // Overall, this function should be replaced by GreedyStrategy(nbh_num);
  // Total Cost Reduce & Cost Reduce for one loop
  double total_cost_reduction = 0, cost_reduction = 0;
  // NOTE(review): `step >= 0` runs up to size()+1 passes rather than size();
  // harmless since the break below fires on convergence, but looks like an
  // off-by-one — confirm intent before tightening.
  for (int32_t step = node_list_.size(); step >= 0; step--) {
    cost_reduction = 0;
    for (SbpNode* this_node : node_list_) {
      // Use GreedyStrategy on Nodes if there is one node left for this
      // connected component. Otherwise, Use GreedyStrategy on Edges.
      if (for_node || this_node->edges_in_.size() + this_node->edges_out_.size() == 0) {
        cost_reduction += this_node->GreedyStrategy();
      } else {
        // GreedyStrategy on Edges.
        for (SbpEdge* this_edge : this_node->edges_out_) {
          double second_rdc = this_edge->GreedyStrategy();
          cost_reduction += second_rdc;
        }
      }
    }
    // Converged: a whole pass produced no improvement.
    if (cost_reduction == 0) { break; }
    total_cost_reduction += cost_reduction;
  }
  return total_cost_reduction;
}
// Greedy strategy adjustment over neighborhoods of up to nbh_num nodes.
// Maintains a circular worklist (pre_visit_node_list with head/tail indices) of
// nodes whose neighborhoods still need adjustment; whenever a node's signature
// changes, its 2-ring neighborhood is re-enqueued. Returns the total cost
// reduction achieved (<= 0).
double SbpGraph::GreedyStrategy(int32_t nbh_num) const {
  // nbh_num is the maximum number of neighborhood to adjust sbp strategy in each step
  // Total Cost Reduce & Cost Reduce for one loop
  double total_cost_reduction = 0, cost_reduction = 0;
  // A global buffer to store part of the one ring neighborhood.
  std::vector<int32_t> nbh_id2node_list_id;
  // Not accept a number lower than 1
  if (nbh_num < 1) { nbh_num = 1; }
  nbh_id2node_list_id.resize(nbh_num);
  std::vector<int32_t> original_sbp_sig_id(nbh_num);
  // store all the node_list_id whose corresponding nodes will be visited
  // We can use unordered_map to do this but vector is faster
  // (size + 1 so that a full queue never makes head == tail ambiguous)
  std::vector<int32_t> pre_visit_node_list(node_list_.size() + 1);
  for (int32_t nbh_id = 0; nbh_id < node_list_.size(); nbh_id++) {
    pre_visit_node_list[nbh_id] = nbh_id;
  }
  int32_t head = 0, tail = node_list_.size();
  // whether a node_list_id is in pre_visit_node_list
  std::vector<bool> pre_visit_tags(node_list_.size(), true);
  // step counts wrap-arounds of head; bounds the total work.
  int32_t step = 0;
  // 1 ring neighborhood buffer
  std::vector<int32_t> nbh_1ring(nbh_num);
  // 2 ring neighborhood buffer
  std::vector<int32_t> nbh_2ring;
  std::vector<bool> node_tags(node_list_.size(), false);
  std::vector<int32_t> nbh_1ring_buffer;
  while (head != tail && step < node_list_.size()) {
    auto* this_node = node_list_[pre_visit_node_list[head]];
    if (nbh_num <= 1) {
      // Greedy strategy on nodes, here we use nbh_1ring to store the nbh_id2node_list_id
      // information for reutilization
      nbh_1ring[0] = this_node->node_list_id_;
      // store the original sbp signature of the 1-ring neighborhood for comparison
      original_sbp_sig_id[0] = this_node->final_sbp_sig_id_;
      cost_reduction = NbhGreedyStrategy(nbh_1ring);
    } else {
      // Use GreedyStrategy on the one ring neighborhood of this node.
      this_node->OneRingNeighborhood(nbh_1ring);
      // store the original sbp signature of the 1-ring neighborhood for comparison
      original_sbp_sig_id.resize(nbh_1ring.size());
      for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {
        original_sbp_sig_id[nbh_id] = node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_;
      }
      if (nbh_1ring.size() <= nbh_num) {
        // The whole ring fits in one brute-force pass.
        cost_reduction = NbhGreedyStrategy(nbh_1ring);
      } else {
        // Use GreedyStrategy on part of the one ring neighborhood.
        // Loop through the neighborhood. Each loop should contain the centroid.
        // Initialize part of the one ring neighborhood
        int32_t nbh_1ring_id = nbh_1ring.size() - nbh_num;
        for (int32_t nbh_id = 1; nbh_id < nbh_num; ++nbh_id) {
          nbh_id2node_list_id[nbh_id] = nbh_1ring[++nbh_1ring_id];
        }
        // loop through the one ring neighborhood; slot 0 is rotated through
        // every ring member while slots 1..nbh_num-1 hold a sliding window.
        cost_reduction = 0;
        int32_t nbh_id = 0;
        for (nbh_1ring_id = 0; nbh_1ring_id < nbh_1ring.size(); ++nbh_1ring_id) {
          nbh_id2node_list_id[nbh_id] = nbh_1ring[nbh_1ring_id];
          cost_reduction += NbhGreedyStrategy(nbh_id2node_list_id);
          // nbh_id for the next step
          if (++nbh_id >= nbh_num) { nbh_id = 1; }
        }
      }
    }
    // change of strategies
    if (cost_reduction != 0) {
      // Add neighborhood into pre-visited node list for each node with changing strategy
      for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {
        // If changes occur
        if (original_sbp_sig_id[nbh_id] != node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_) {
          // schedule to visit the neighborhood of that changing node
          node_list_[nbh_1ring[nbh_id]]->NRingNeighborhood(2, nbh_2ring, nbh_1ring_buffer,
                                                           node_list_, node_tags);
          for (int32_t nbh_node_list_id : nbh_2ring) {
            // Put them into the pre-visited node list (circular enqueue)
            if (!pre_visit_tags[nbh_node_list_id]) {
              pre_visit_node_list[tail] = nbh_node_list_id;
              pre_visit_tags[nbh_node_list_id] = true;
              tail++;
              if (tail == pre_visit_node_list.size()) { tail = 0; }
            }
          }
        }
      }
    }
    // Finish visiting: dequeue head (circularly), counting wrap-arounds in step.
    pre_visit_tags[pre_visit_node_list[head]] = false;
    head++;
    if (head == pre_visit_node_list.size()) {
      head = 0;
      step++;
    }
    total_cost_reduction += cost_reduction;
  }
  return total_cost_reduction;
}
// Branch-and-bound DFS over the SBP signatures of a neighborhood, visited in
// the precomputed `order`. Mutates each node's final_sbp_sig_id_ as scratch
// state; on finding a cheaper full assignment it records it in min_sbp_sig_id
// and lowers min_cost. Pruning uses order2acc_min_in_nbh_cost, a lower bound on
// the remaining cost from `order` to the end.
void SbpGraph::DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,
                             std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                             std::vector<int32_t>& order2nbh_id,
                             std::vector<int32_t>& nbh_id2order,
                             std::vector<double>& order2acc_min_in_nbh_cost,
                             std::vector<std::vector<double>>& out_nbh_costs,
                             std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,
                             std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,
                             double curr_cost) const {
  // We have finished visiting the neighborhood
  if (order >= nbh_id2node_list_id.size()) {
    // relative difference > 1e-12 (filters out floating point noise)
    if (curr_cost < min_cost * kFloatDeviationMinus) {
      min_cost = curr_cost;
      for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {
        min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;
      }
    }
    return;
  }
  // Pruning, remove all those branch with large cost
  if (curr_cost + order2acc_min_in_nbh_cost[order] >= min_cost) { return; }
  // Deep first search in the next order. Signatures are tried in ascending
  // out-neighborhood cost (nbh_id2order2sbp_id) so cheap branches come first.
  int32_t nbh_id = order2nbh_id[order];
  SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
  for (int32_t sbp_id : nbh_id2order2sbp_id[nbh_id]) {
    sbp_node->final_sbp_sig_id_ = sbp_id;
    DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,
                  order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,
                  min_cost, order + 1,
                  curr_cost + out_nbh_costs[nbh_id][sbp_id]
                      + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order));
  }
}
// DFS with backtracking that looks for ANY signature assignment of the
// neighborhood whose per-node cost stays below GetValidMaxCopyCost(). Leaves
// final_sbp_sig_id_ set to the found assignment on success. Returns true once
// one such strategy exists from position nbh_id onwards.
bool SbpGraph::DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,
                                     std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                                     std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const {
  // We found such a strategy
  if (nbh_id == nbh_id2order.size()) { return true; }
  SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
  // Start from B. (iterate signatures from the last index downwards;
  // presumably the broadcast signature sits at the end — confirm in SbpNode)
  for (int32_t sbp_id = sbp_node->cost_.size() - 1; sbp_id >= 0; sbp_id--) {
    sbp_node->final_sbp_sig_id_ = sbp_id;
    // If the cost for this node is reasonable, then go to the next one
    if (sbp_node->cost_[sbp_id] + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order)
        < GetValidMaxCopyCost()) {
      if (DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order,
                                nbh_id + 1)) {
        // If we found one strategy, then exit the Dfs.
        return true;
      }
    }
  }
  // Can not find a reasonable strategy with the setting for previous nodes.
  // Go back and change the previous node.
  return false;
}
// Find one strategy with finite cost for adjustment.
// Orders all nodes by a cut-ratio-guided BFS (restarting from the unvisited
// node with the smallest cut ratio for each connected component), then runs a
// backtracking DFS to find any assignment whose cost stays below the valid
// maximum copy cost. CHECK-fails if no such strategy exists.
Maybe<void> SbpGraph::Find1Strategy4Greedy() const {
  std::vector<int32_t> nbh_id2node_list_id;
  std::vector<bool> not_visited(node_list_.size(), true);
  std::vector<int32_t> nbh_1ring;
  int32_t head = 0;
  int32_t tail = 0;
  std::vector<double> node_cut_ratios(node_list_.size());
  // Initialize cut ratio for all the nodes
  for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {
    node_cut_ratios[node_list_id] = node_list_[node_list_id]->GetCutRatio();
  }
  // If have not visited all the nodes (outer loop: one BFS per connected component)
  while (tail < node_list_.size()) {
    // Find the node with the minimum cut ratio as the component's seed.
    // 2.0 is a sentinel above any valid ratio.
    int32_t node_with_min_cut_ratio = -1;
    double min_cut_ratio = 2.0;
    for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {
      if (not_visited[node_list_id]) {
        double curr_cut_ratio = node_cut_ratios[node_list_id];
        if (curr_cut_ratio < min_cut_ratio) {
          min_cut_ratio = curr_cut_ratio;
          node_with_min_cut_ratio = node_list_id;
        }
      }
    }
    // put this node into the open set
    nbh_id2node_list_id.push_back(node_with_min_cut_ratio);
    not_visited[node_with_min_cut_ratio] = false;
    tail++;
    // BFS over the component; head/tail index into nbh_id2node_list_id.
    while (head < tail) {
      // look for the neighborhood of the head
      int32_t node_list_id = nbh_id2node_list_id[head];
      node_list_[node_list_id]->OneRingNeighborhood(nbh_1ring);
      // sort neighbors so lower-cut-ratio nodes are enqueued (and later
      // DFS-visited) first
      std::sort(nbh_1ring.begin(), nbh_1ring.end(),
                [&](int32_t i, int32_t j) { return node_cut_ratios[i] < node_cut_ratios[j]; });
      for (int32_t curr_id : nbh_1ring) {
        if (not_visited[curr_id]) {
          nbh_id2node_list_id.push_back(curr_id);
          tail++;
          not_visited[curr_id] = false;
        }
      }
      head++;
    }
  }
  // mapping from the node_list_id to the id in the nbh_id2node_list_id
  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;
  InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);
  // Initial an ordinary order (identity: visit in BFS discovery order)
  std::vector<int32_t> nbh_id2order(nbh_id2node_list_id.size());
  for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {
    nbh_id2order[nbh_id] = nbh_id;
  }
  // Combining deep first search and pruning based on cut ratio
  CHECK(DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order,
                              /*nbh_id=*/0))
      << "Can't find a reasonable strategy!";
  return Maybe<void>::Ok();
}
// Use brute force to search for a strategy with minimum cost for a neighborhood.
// Precomputes per-signature out-of-neighborhood costs and a visit order (largest
// cost spread first), then runs the pruned DFS (DfsAddNbhCost). Applies the best
// assignment found and returns the cost reduction (<= 0), or 0.0 if no
// improvement beyond floating-point noise was found.
double SbpGraph::NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const {
  // number of nodes in the neighborhood
  int32_t num_nbh = nbh_id2node_list_id.size();
  // mapping from the node_list_id to the id in the nbh_id2node_list_id
  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;
  InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);
  // a sbp signature id set minimizing the overall cost, store the original one as default
  std::vector<int32_t> min_sbp_sig_id(num_nbh);
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;
  }
  // pre-compute and store the cost between neighborhood and outside nodes under different sbp for
  // each node within the neighborhood
  // (final_sbp_sig_id_ is temporarily clobbered here; restored further below)
  std::vector<std::vector<double>> out_nbh_costs(num_nbh);
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
    out_nbh_costs[nbh_id].resize(sbp_node->cost_.size());
    for (int32_t sbp_id = sbp_node->cost_.size() - 1; sbp_id >= 0; sbp_id--) {
      sbp_node->final_sbp_sig_id_ = sbp_id;
      out_nbh_costs[nbh_id][sbp_id] = sbp_node->EvalOutNbhCost(node_list_id2nbh_id);
    }
  }
  // pre-compute and store the order of the out_nbh_costs (ascending)
  std::vector<std::vector<int32_t>> nbh_id2order2sbp_id(num_nbh);
  auto CompareDoubleLess = [](double a, double b) { return a < b; };
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    DecideOrder(out_nbh_costs[nbh_id], nbh_id2order2sbp_id[nbh_id], CompareDoubleLess);
  }
  // Decide the order to go through the neighborhood.
  // Should visit those nodes with a larger difference in the out cost first.
  std::vector<double> out_nbh_cost_diff(num_nbh);
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    out_nbh_cost_diff[nbh_id] =
        *std::max_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())
        - *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end());
  }
  std::vector<int32_t> order2nbh_id;
  DecideOrder(out_nbh_cost_diff, order2nbh_id, [](double a, double b) { return a > b; });
  // Find the inverse map of order
  std::vector<int32_t> nbh_id2order;
  InverseOrder(order2nbh_id, nbh_id2order);
  // Current Cost, Minimum Cost, Cost with original sbp
  double original_cost = 0;
  // Recover original sbp (undo the scratch mutation above)
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];
  }
  // Compute cost with original sbp
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
    original_cost += out_nbh_costs[nbh_id][min_sbp_sig_id[nbh_id]];
    original_cost += sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order);
  }
  double min_cost = original_cost;
  // Accumulate minimum cost from the current node to the end of the neighborhood node list.
  // The accumulated cost include the current node. Used as the DFS pruning bound.
  std::vector<double> order2acc_min_in_nbh_cost(num_nbh);
  order2acc_min_in_nbh_cost[num_nbh - 1] =
      *std::min_element(out_nbh_costs[order2nbh_id[num_nbh - 1]].begin(),
                        out_nbh_costs[order2nbh_id[num_nbh - 1]].end());
  for (int32_t order = num_nbh - 2; order >= 0; order--) {
    int32_t nbh_id = order2nbh_id[order];
    order2acc_min_in_nbh_cost[order] =
        order2acc_min_in_nbh_cost[order + 1]
        + *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())
        + node_list_[nbh_id2node_list_id[nbh_id]]->EvalMinInNbhCost(node_list_id2nbh_id,
                                                                    nbh_id2order);
  }
  // Use brute force (DFS) to adjust for the best strategy in the neighborhood.
  DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,
                order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,
                min_cost, /*order=*/0, /*curr_cost=*/0);
  // Use the sbp strategy with minimum cost
  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
    node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];
  }
  if (min_cost < original_cost) {
    // Directly return (min_cost - original_cost) might have floating point error up to 3e-16
    // For example, original_cost: 2.22507e+06, min_cost: 2.22507e+06,
    // diff: -4.65661e-10, relative diff:2.09279e-16
    // Therefore, we use a threshold to filter out such fake true detection to
    // avoid unlimited search.
    if (original_cost * kFloatDeviationMinus > min_cost) { return min_cost - original_cost; }
  }
  return 0.0;
}
// Select and Merge two nodes.
// Preference 1: merge the two endpoints of the edge with the smallest cut ratio.
// Preference 2 (no such edge): merge the pair of nodes with the most similar
// neighborhoods (largest intersection of neighbor sets), provided the merged
// node would not exceed threshold_ signatures.
// Returns 1 if a merge happened, 0 otherwise.
int32_t SbpGraph::PickAndMerge() {
  if (node_list_.size() < kMinNodeInGraphForMerging) { return 0; }
  // Pick the one with the smallest cut ratio
  double min_cut_ratio = 1.0;
  double curr_cut_ratio = 0.0;
  SbpEdge* merging_edge = nullptr;
  for (int32_t i = 0; i < node_list_.size(); i++) {
    for (SbpEdge* edge_in : node_list_[i]->edges_in_) {
      curr_cut_ratio = edge_in->FindCutRatio(threshold_);
      if (curr_cut_ratio < min_cut_ratio) {
        min_cut_ratio = curr_cut_ratio;
        merging_edge = edge_in;
      }
    }
  }
  if (merging_edge != nullptr) {
    // Merge two nodes on the edge with the minimum cut ratio
    return NodeMerging(merging_edge->start_node_, merging_edge->end_node_);
  } else {
    // Pick the couple with the largest similar neighborhood
    std::vector<BinarySet> node_binary_sets(node_list_.size());
    for (int32_t i = 0; i < node_list_.size(); i++) {
      // Transfer edge to binary set: the set of node i holds i itself plus all
      // of its neighbors (in-neighbors via start_node_, out-neighbors via
      // end_node_).
      node_binary_sets[i].Initialize(node_list_.size());
      node_binary_sets[i].AddEntry(i);
      for (const SbpEdge* edge_in : node_list_[i]->edges_in_) {
        node_binary_sets[i].AddEntry(edge_in->start_node_->node_list_id_);
      }
      for (const SbpEdge* edge_out : node_list_[i]->edges_out_) {
        // BUGFIX: the neighbor along an out-edge is its end node; the start node
        // of an out-edge of node i is i itself, so the previous code
        // (start_node_) never registered any out-neighbor.
        node_binary_sets[i].AddEntry(edge_out->end_node_->node_list_id_);
      }
    }
    // Find two nodes with largest common subset
    // buffer of binary set
    BinarySet buffer_binary_set(node_list_.size());
    // Number of common neighbors
    int32_t max_comm_edge_num = 0, curr_comm_edge_num = 0;
    // Initialized defensively; only read when max_comm_edge_num > 0 below.
    int32_t min_node_pair[2] = {-1, -1};
    // Number of Sbp Signature in merged node
    int32_t min_sbp_num = 0, curr_sbp_num = 0;
    for (int32_t i = 0; i < node_list_.size(); i++) {
      for (int32_t j = i + 1; j < node_list_.size(); j++) {
        curr_sbp_num = node_list_[i]->cost_.size() * node_list_[j]->cost_.size();
        // Skip pairs whose merged signature count would exceed the threshold.
        if (curr_sbp_num <= threshold_) {
          node_binary_sets[i].IntersectionTo(node_binary_sets[j], buffer_binary_set);
          curr_comm_edge_num = buffer_binary_set.Total();
          // Prefer more common neighbors; break ties with fewer merged signatures.
          if (curr_comm_edge_num > max_comm_edge_num
              || (curr_comm_edge_num == max_comm_edge_num && curr_sbp_num < min_sbp_num)) {
            min_node_pair[0] = i;
            min_node_pair[1] = j;
            max_comm_edge_num = curr_comm_edge_num;
            min_sbp_num = curr_sbp_num;
          }
        }
      }
    }
    if (max_comm_edge_num > 0) {
      return NodeMerging(node_list_[min_node_pair[0]], node_list_[min_node_pair[1]]);
    } else {
      return 0;
    }
  }
}
// Clip an edge: detach it from both endpoints' adjacency lists and destroy it.
// Destroying the edge also releases whatever nodes/edges it absorbed during
// elimination, so clip with care on a shrunk graph.
void SbpGraph::ClipEdge(SbpEdge* this_edge) const {
  // The two removals touch disjoint lists, so their order is irrelevant.
  CheckAndRemoveFrom<SbpEdge*>(this_edge->start_node_->edges_out_, this_edge);
  CheckAndRemoveFrom<SbpEdge*>(this_edge->end_node_->edges_in_, this_edge);
  delete this_edge;
}
// Compute the minimum and maximum layer of each node in the graph.
// Four passes over node_list_: (1) propagate minimum layers, (2) find the
// largest minimum layer, (3) propagate maximum layers, (4) lift maximum layers
// up to the global bound. Returns the largest minimum layer.
int32_t SbpGraph::ComputeLayer(
    HashMap<std::string, SbpNode*>& op_name2sbp_node,
    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) const {
  // Pass 1: minimum layer per node.
  for (SbpNode* node : node_list_) {
    node->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
  }
  // Pass 2: the largest minimum layer over the whole graph.
  int32_t max_min_layer = -1;
  for (SbpNode* node : node_list_) { max_min_layer = std::max(max_min_layer, node->min_layer_); }
  // Pass 3: maximum layer per node.
  for (SbpNode* node : node_list_) {
    node->SpreadMaxLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
  }
  // Pass 4: clamp/lift each node's maximum layer to the global bound.
  for (SbpNode* node : node_list_) { node->LiftMaxLayer(max_min_layer); }
  return max_min_layer;
}
// Find the trunk of the sbp graph, then reduce the wait time for tributaries.
// A node is on the trunk if its min layer is late enough that the accumulated
// per-layer minimum cost exceeds half the wait time. Tributary nodes get their
// available wait time spread from the trunk; trunk nodes get their wait time
// reduced layer by layer from the end, consuming accumulated tributary cost.
void SbpGraph::FindTrunk(int32_t max_min_layer,
                         HashMap<std::string, SbpNode*>& op_name2sbp_node) const {
  // Summarize cost for each layer, on the trunk or tributaries
  std::vector<double> trunk_cost(max_min_layer + 1, 0);
  for (SbpNode* this_node : node_list_) {
    trunk_cost[this_node->min_layer_] += this_node->GetMinCost();
  }
  // Decide trunks
  double acc_cost = 0;
  // All the nodes with MinLayer>=trunk_end_id would be considered as trunks
  int32_t trunk_end_id = max_min_layer;
  for (int32_t layer_id = max_min_layer; layer_id >= 0; layer_id--) {
    acc_cost += trunk_cost[layer_id];
    if (acc_cost > 0.5 * wait_time_) {
      trunk_end_id = layer_id;
      break;
    }
  }
  // Find out all the nodes on the trunk.
  for (SbpNode* this_node : node_list_) {
    if (this_node->min_layer_ >= trunk_end_id) { this_node->SpreadTrunk(op_name2sbp_node); }
  }
  // Compute maximum layer for tributaries
  // Clear counter and initialize tributary layer for each sbp node
  for (SbpNode* this_node : node_list_) {
    this_node->counter_ = 0;
    this_node->DropTributaryLayer(max_min_layer);
  }
  // Count the number of consumers and downstream nodes
  for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }
  // Compute maximum layer for tributaries
  for (SbpNode* this_node : node_list_) { this_node->SpreadTributaryLayer(op_name2sbp_node); }
  // Summarize cost for each layer on the trunk, store it to avoid subtraction of large values.
  trunk_cost.assign(max_min_layer + 1, 0);
  // tributary cost start from each min layer
  std::vector<double> tributary_cost(max_min_layer + 1, 0);
  // tributary cost would be outdated after Max Layer (before Max Layer + 1)
  std::vector<double> outdated_tributary_cost(max_min_layer + 1, 0);
  // the trunk operators, grouped by min layer
  std::vector<std::vector<SbpNode*>> trunk_ops(max_min_layer + 1);
  for (SbpNode* this_node : node_list_) {
    if (this_node->on_trunk_) {
      trunk_cost[this_node->min_layer_] += this_node->GetMinCost();
      trunk_ops[this_node->min_layer_].emplace_back(this_node);
    } else {
      double curr_min_cost = this_node->GetMinCost();
      tributary_cost[this_node->min_layer_] += curr_min_cost;
      outdated_tributary_cost[this_node->tributary_layer_] += curr_min_cost;
    }
  }
  // Accumulate the cost from the consumer to the end, not including itself
  std::vector<double> acc_trunk_cost(max_min_layer + 1, 0);
  for (int32_t layer_id = max_min_layer; layer_id > 0; layer_id--) {
    acc_trunk_cost[layer_id - 1] = acc_trunk_cost[layer_id] + trunk_cost[layer_id];
  }
  // Clear counter for each sbp node
  for (SbpNode* this_node : node_list_) { this_node->counter_ = 0; }
  // Count the number of consumers and downstream nodes
  for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }
  // Reduce the wait time for tributaries
  for (SbpNode* this_node : node_list_) {
    this_node->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time_);
  }
  // Reduce the wait time for trunk from the end to the begin
  double acc_tributary_cost = outdated_tributary_cost[max_min_layer];
  double used_tributary_cost = 0.0;
  double curr_wait_time = 0.0;
  for (int32_t layer_id = max_min_layer - 1; layer_id >= 0; layer_id--) {
    // Can not move it backward since we need to do this at the 0th layer.
    // At some moment, the cost haven't been used would disappear.
    if (tributary_cost[layer_id + 1] > used_tributary_cost) {
      acc_tributary_cost -= tributary_cost[layer_id + 1] - used_tributary_cost;
      used_tributary_cost = 0.0;
      if (acc_tributary_cost < 0.0) {
        // should not happen besides floating point error
        std::cout << "Caution! Current accumulated tributary cost is: " << acc_tributary_cost
                  << std::endl;
        acc_tributary_cost = 0.0;
      }
    } else {
      used_tributary_cost -= tributary_cost[layer_id + 1];
    }
    // accumulate tributary cost at this layer
    acc_tributary_cost += outdated_tributary_cost[layer_id];
    // If we have more cost in tributaries, we reduce the wait time
    // This code maintains the invariant ( acc_tributary_cost + used_tributary_cost )
    if (acc_tributary_cost > 0.0) {
      if (acc_tributary_cost > wait_time_) {
        // Enough tributary cost to fully cover the wait time at this layer.
        curr_wait_time = 0.0;
        acc_tributary_cost -= wait_time_;
        used_tributary_cost += wait_time_;
      } else {
        // Partial coverage: the trunk still waits for the remainder.
        curr_wait_time = wait_time_ - acc_tributary_cost;
        used_tributary_cost += acc_tributary_cost;
        acc_tributary_cost = 0.0;
      }
      // Reduce the wait time in the trunk
      for (SbpNode* this_node : trunk_ops[layer_id]) {
        this_node->SetTrunkWaitTime(curr_wait_time);
      }
    }
  }
}
// Setter for the overlayable wait time applied to copy costs before
// inter-device communication.
void SbpGraph::SetWaitTime(double wait_time) { wait_time_ = wait_time; }
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/sbp_graph.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
#include <algorithm>
#include <unordered_map>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
namespace
auto_parallel
{
// A graph structure to deal with the SBP strategy.
// It contains a lot of eliminations to shrink the topology of the original graph.
// Furthermore, it contains some adjustment tricks for searching a good strategy
// in the shrunk graph.
class SbpGraph final {
 public:
  // Constructor
  SbpGraph() = default;
  // Destructor
  ~SbpGraph();
  OF_DISALLOW_COPY_AND_MOVE(SbpGraph);
  // Identity comparison: two SbpGraph references are equal only if they are the
  // same object.
  bool operator==(const SbpGraph& other) { return this == &other; }
  // Randomly assign a SbpSignature strategy
  void RandomSbpSignature(bool use_sbp_collector) const;
  // Assign signature 0 to every node to avoid randomness
  void SetDefaultSbpSig() const;
  // Compute the total cost (node + edge) of the current strategy
  double ComputeCost() const;
  // Generate a node owned by this graph
  SbpNode* GenerateNode();
  // Merge all parallel edges & Check and eliminate all nodes with only one
  // degree-in and one degree-out
  int32_t NodeAndEdgeEliminations();
  // Finalize Sbp Cost for the whole graph
  void FinalizeSbp() const;
  // Use Greedy Strategy to decide Sbp for Nodes in node_list_. Should be used
  // after we have an initial strategy.
  // Set for_node to true to only use GreedyStrategy on Nodes.
  double GreedyStrategy(bool for_node) const;
  // Use greedy strategy on the one ring neighborhood with the maximum number of points nbh_num.
  double GreedyStrategy(int32_t nbh_num = 4) const;
  // Find one strategy with finite cost for adjustment
  Maybe<void> Find1Strategy4Greedy() const;
  // Use brute force to search for a strategy with minimum cost for a neighborhood
  double NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const;
  // Set threshold_ for SbpNode Merging
  void SetThreshold(int32_t threshold) { threshold_ = threshold; }
  // Clip an edge, remove it from graph
  // Clipping an edge will also delete the nodes and edges contained in this edge. Though not
  // suffering from any compiling and runtime bugs, clipping an edge on a shrunk graph is not
  // recommended. We should carefully think about it before any clipping.
  void ClipEdge(SbpEdge* this_edge) const;
  // Compute the minimum and maximum layer of each node in the graph
  int32_t ComputeLayer(HashMap<std::string, SbpNode*>& op_name2sbp_node,
                       const HashMap<const OpNode*, HashSet<std::string>>&
                           op_node2mutable_op_ctrl_deps) const;
  // Find the trunk of the sbp graph, then reduce the wait time for tributaries
  void FindTrunk(int32_t max_min_layer, HashMap<std::string, SbpNode*>& op_name2sbp_node) const;
  // Set wait time
  void SetWaitTime(double wait_time);

 private:
  friend class SbpCollector;
  friend class SbpConstructor;
  // All the nodes
  std::vector<SbpNode*> node_list_;
  // Limitation: Merged node should not have a number of Sbp Signature greater
  // than threshold.
  int32_t threshold_ = 100;
  // Overlayable wait time for copy cost, which occurs before communication between devices.
  double wait_time_ = 16500.0;
  // Remove a node from the node list
  void RemoveFromNodeList(SbpNode* this_node);
  // Check and eliminate one node with only one degree-in and one degree-out
  int32_t NodeElimination(SbpNode* this_node);
  // Merge all parallel edges with given start_node_ and end_node_
  int32_t EdgeElimination(SbpNode* this_node) const;
  // Check and eliminate one child node
  int32_t ChildElimination(SbpNode* this_node);
  // Merge two nodes
  int32_t NodeMerging(SbpNode* first, SbpNode* second);
  // Select two nodes and merge them
  int32_t PickAndMerge();
  // Branch-and-bound DFS over a neighborhood's signature assignments;
  // records the best assignment in min_sbp_sig_id / min_cost.
  void DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,
                     std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                     std::vector<int32_t>& order2nbh_id, std::vector<int32_t>& nbh_id2order,
                     std::vector<double>& order2acc_min_in_nbh_cost,
                     std::vector<std::vector<double>>& out_nbh_costs,
                     std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,
                     std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,
                     double curr_cost) const;
  // Backtracking DFS that looks for any assignment whose cost stays below the
  // valid maximum copy cost.
  bool DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,
                             std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                             std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const;
};
}
// namespace auto_parallel
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
oneflow/core/auto_parallel/sbp_node.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
namespace
oneflow
{
namespace
auto_parallel
{
// function in cpp. Should be put in one file due to use of template
// Otherwise we will need to declare specific template at the end of cpp file.
SbpNode
::
SbpNode
(
SbpNode
*
first
,
SbpNode
*
second
)
{
half_node_
.
resize
(
2
);
half_node_
[
0
]
=
first
;
half_node_
[
1
]
=
second
;
// Get the edge between first and second
// NOTE: It must zero or one edge between them
SbpEdge
*
common_edge
=
nullptr
;
for
(
int32_t
k
=
0
;
k
<
first
->
edges_in_
.
size
();
k
++
)
{
if
(
first
->
edges_in_
[
k
]
->
start_node_
==
second
)
{
// CHECK_ISNULL(edge);
common_edge
=
first
->
edges_in_
[
k
];
}
}
for
(
int32_t
k
=
0
;
k
<
first
->
edges_out_
.
size
();
k
++
)
{
if
(
first
->
edges_out_
[
k
]
->
end_node_
==
second
)
{
common_edge
=
first
->
edges_out_
[
k
];
}
}
// Find all available merged-SbpSignature(edge's cost less than threshold).
if
(
common_edge
)
{
double
min_cost
=
GetMaxVal
<
float
>
();
for
(
const
auto
&
row
:
common_edge
->
cost_
)
{
for
(
const
double
&
c
:
row
)
min_cost
=
std
::
min
(
min_cost
,
c
);
}
// If there is no one case can choose, we will blow up
for
(
int32_t
i
=
0
;
i
<
first
->
cost_
.
size
();
i
++
)
{
for
(
int32_t
j
=
0
;
j
<
second
->
cost_
.
size
();
j
++
)
{
const
double
edge_cost
=
common_edge
->
start_node_
==
first
?
common_edge
->
cost_
[
i
][
j
]
:
common_edge
->
cost_
[
j
][
i
];
if
(
edge_cost
<
GetValidMaxCopyCost
())
{
merged_sig_id2children_sig_id_
.
emplace_back
(
std
::
make_pair
(
i
,
j
));
cost_
.
emplace_back
(
edge_cost
+
first
->
cost_
[
i
]
+
second
->
cost_
[
j
]);
}
}
}
CHECK
(
merged_sig_id2children_sig_id_
.
size
()
>
0
)
<<
"0 size for merge child edge, min cost: "
<<
min_cost
;
}
else
{
for
(
int32_t
i
=
0
;
i
<
first
->
cost_
.
size
();
i
++
)
{
for
(
int32_t
j
=
0
;
j
<
second
->
cost_
.
size
();
j
++
)
{
merged_sig_id2children_sig_id_
.
emplace_back
(
std
::
make_pair
(
i
,
j
));
cost_
.
emplace_back
(
first
->
cost_
[
i
]
+
second
->
cost_
[
j
]);
}
}
}
// Initialize default sbp choice
// If the original sbp pair does not go through, then use 0 as default.
final_sbp_sig_id_
=
0
;
// Track the original strategy
for
(
int32_t
sig_id
=
0
;
sig_id
<
merged_sig_id2children_sig_id_
.
size
();
sig_id
++
)
{
if
(
merged_sig_id2children_sig_id_
[
sig_id
].
first
==
first
->
final_sbp_sig_id_
&&
merged_sig_id2children_sig_id_
[
sig_id
].
second
==
second
->
final_sbp_sig_id_
)
{
final_sbp_sig_id_
=
sig_id
;
}
}
// Merge edges_in_
edges_in_
.
reserve
(
first
->
edges_in_
.
size
()
+
second
->
edges_in_
.
size
());
edges_in_
.
insert
(
edges_in_
.
end
(),
first
->
edges_in_
.
begin
(),
first
->
edges_in_
.
end
());
edges_in_
.
insert
(
edges_in_
.
end
(),
second
->
edges_in_
.
begin
(),
second
->
edges_in_
.
end
());
// Merge edges_out_
edges_out_
.
reserve
(
first
->
edges_out_
.
size
()
+
second
->
edges_out_
.
size
());
edges_out_
.
insert
(
edges_out_
.
end
(),
first
->
edges_out_
.
begin
(),
first
->
edges_out_
.
end
());
edges_out_
.
insert
(
edges_out_
.
end
(),
second
->
edges_out_
.
begin
(),
second
->
edges_out_
.
end
());
// Merge SbpEdge Cost
for
(
SbpEdge
*&
this_edge
:
first
->
edges_in_
)
{
this_edge
->
DuplicateCost
(
false
,
true
,
merged_sig_id2children_sig_id_
);
this_edge
->
end_node_
=
this
;
}
for
(
SbpEdge
*&
this_edge
:
first
->
edges_out_
)
{
this_edge
->
DuplicateCost
(
true
,
true
,
merged_sig_id2children_sig_id_
);
this_edge
->
start_node_
=
this
;
}
for
(
SbpEdge
*&
this_edge
:
second
->
edges_in_
)
{
this_edge
->
DuplicateCost
(
false
,
false
,
merged_sig_id2children_sig_id_
);
this_edge
->
end_node_
=
this
;
}
for
(
SbpEdge
*&
this_edge
:
second
->
edges_out_
)
{
this_edge
->
DuplicateCost
(
true
,
false
,
merged_sig_id2children_sig_id_
);
this_edge
->
start_node_
=
this
;
}
// Remove edges from original nodes
first
->
edges_in_
.
clear
();
first
->
edges_out_
.
clear
();
second
->
edges_in_
.
clear
();
second
->
edges_out_
.
clear
();
// Move edges between two nodes to each half node
for
(
int32_t
k
=
edges_out_
.
size
()
-
1
;
k
>=
0
;
k
--
)
{
if
(
edges_out_
[
k
]
->
end_node_
==
this
)
{
// Remove this edge from edges_out_ and edges_in_ and put it inside the node
CheckAndRemoveFrom
<
SbpEdge
*>
(
edges_in_
,
edges_out_
[
k
]);
first
->
edges_out_
.
emplace_back
(
edges_out_
[
k
]);
second
->
edges_in_
.
emplace_back
(
edges_out_
[
k
]);
RemoveFrom
<
SbpEdge
*>
(
edges_out_
,
k
);
}
}
}
// Destructor: a node owns its outgoing edges, the single edge connecting each eliminated
// child, the child nodes themselves, and (for merged nodes) both half nodes.
// NOTE(review): incoming edges are not deleted here — presumably owned by their start node
// via edges_out_; confirm against SbpGraph's ownership convention.
SbpNode::~SbpNode() {
  for (auto& edge_out : edges_out_) { delete edge_out; }
  for (auto& child_node : children_) {
    // An eliminated child keeps exactly one connecting edge in edges_in_ (if any).
    if (child_node->edges_in_.size()) { delete child_node->edges_in_[0]; }
    delete child_node;
  }
  for (auto& half_node : half_node_) { delete half_node; }
}
// Initialize the per-signature cost vector from the signature list.
// Records the total number of signatures and sizes cost_ to match (values start at 0.0).
void SbpNode::InitializeSbp() {
  global_sbp_sig_size_ = sbp_sig_list_.size();
  cost_.resize(sbp_sig_list_.size());
};
// Let one node point to another
// Connect start_node -> end_node with a freshly allocated SbpEdge and register the
// edge in both endpoints' adjacency lists.
void SbpNode::StartPointToEnd(SbpNode* start_node, SbpNode* end_node) {
  // The edge is owned by the graph via the start node's edges_out_ list.
  SbpEdge* connecting_edge = new SbpEdge(start_node, end_node);
  start_node->edges_out_.push_back(connecting_edge);
  end_node->edges_in_.push_back(connecting_edge);
};
// Create an edge start_node -> this (this node consumes from start_node).
void SbpNode::PointFrom(SbpNode* start_node) { StartPointToEnd(start_node, this); };
// Create an edge this -> end_node (this node produces for end_node).
void SbpNode::PointTo(SbpNode* end_node) { StartPointToEnd(this, end_node); };
// Fold the cost of newly eliminated children into this node's cost vector.
// For each new child and each of this node's signatures, pick the child signature that
// minimizes (connecting-edge cost + child cost), remember that choice in
// child_node_sbp_sig_, and add the minimum to cost_[sbp_this].
// Idempotent for already-summarized children: only indices past the previous size are
// processed.
void SbpNode::SummarizeCost() {
  // Nothing to do if every child has already been summarized.
  if (children_.size() == child_node_sbp_sig_.size()) { return; }
  int32_t previous_children_size = child_node_sbp_sig_.size();
  child_node_sbp_sig_.resize(children_.size());
  // Only deal with new children_
  for (int32_t child = previous_children_size; child < children_.size(); child++) {
    child_node_sbp_sig_[child].resize(cost_.size());
    for (int32_t sbp_this = 0; sbp_this < cost_.size(); sbp_this++) {
      double min_cost = 0, curr_cost = 0;
      for (int32_t sbp_child = 0; sbp_child < children_[child]->cost_.size(); sbp_child++) {
        // The child keeps exactly one connecting edge, either incoming or outgoing;
        // the cost matrix orientation follows the edge direction.
        if (children_[child]->edges_in_.size()) {
          // edge in graph: father -> child
          curr_cost = children_[child]->edges_in_[0]->cost_[sbp_this][sbp_child]
                      + children_[child]->cost_[sbp_child];
        } else {
          // edge in graph: child -> father
          curr_cost = children_[child]->edges_out_[0]->cost_[sbp_child][sbp_this]
                      + children_[child]->cost_[sbp_child];
        }
        // update min_cost with fixed SbpSignature for this node and child node
        if (sbp_child == 0 || curr_cost < min_cost) {
          min_cost = curr_cost;
          child_node_sbp_sig_[child][sbp_this] = sbp_child;
        }
      }
      // Add the cost for child node to this node
      cost_[sbp_this] += min_cost;
    }
  }
}
// Try to eliminate this node as a degree-1 "child": if it has exactly one edge in total,
// attach it to the node on the other end (the "father"), remove the connecting edge from
// the father's adjacency list, and fold this node's cost into the father.
// Returns true on success; the caller (SbpGraph) is responsible for removing this node
// from the node list afterwards.
bool SbpNode::EliminateItselfAsChild() {
  if (edges_in_.size() + edges_out_.size() == 1) {
    if (edges_in_.size()) {
      // edge in graph: father -> this_node
      SbpNode* father = edges_in_[0]->start_node_;
      father->children_.emplace_back(this);
      // Detach the connecting edge from the father; it stays in this->edges_in_ so
      // SummarizeCost/the destructor can still reach it.
      CheckAndRemoveFrom<SbpEdge*>(father->edges_out_, edges_in_[0]);
      father->SummarizeCost();
    } else {
      // edge in graph: this_node -> father
      SbpNode* father = edges_out_[0]->end_node_;
      father->children_.emplace_back(this);
      CheckAndRemoveFrom<SbpEdge*>(father->edges_in_, edges_out_[0]);
      father->SummarizeCost();
    }
    // successfully eliminate this node
    return true;
  }
  // can not eliminate this node
  return false;
}
// Propagate the chosen signature (final_sbp_sig_id_) downwards: first decode it into the
// two half nodes of a merged node, then into eliminated children, then recurse so that
// every nested node and edge fixes its own final choice.
void SbpNode::FinalizeSbp() {
  if (!half_node_.empty()) {
    // Finalize Sbp of merged nodes: decode the merged signature id back into the
    // (first, second) child signature pair recorded at merge time.
    half_node_[0]->final_sbp_sig_id_ = merged_sig_id2children_sig_id_[final_sbp_sig_id_].first;
    half_node_[1]->final_sbp_sig_id_ = merged_sig_id2children_sig_id_[final_sbp_sig_id_].second;
  }
  // Finalize Sbp of children_: each child's best signature was precomputed per parent
  // signature in SummarizeCost.
  for (int32_t i = 0; i < children_.size(); i++) {
    children_[i]->final_sbp_sig_id_ = child_node_sbp_sig_[i][this->final_sbp_sig_id_];
  }
  // Finalize Sbp of half_node_ Attachment
  if (!half_node_.empty()) {
    half_node_[0]->FinalizeSbp();
    half_node_[1]->FinalizeSbp();
  }
  // Finalize Sbp of edges in edges_out_
  for (const auto& edge_out : edges_out_) { edge_out->FinalizeSbp(); }
  // Finalize Sbp again in case of the node on the other side is not finalized
  // yet. This may happen when Two side of an edge merged into two larger nodes
  // and this edge is just a sub edge.
  for (const auto& edge_in : edges_in_) { edge_in->FinalizeSbp(); }
  // Finalize Sbp of children_ Attachment
  for (int32_t i = 0; i < children_.size(); i++) {
    children_[i]->FinalizeSbp();
    for (const auto& edge_in : children_[i]->edges_in_) { edge_in->FinalizeSbp(); }
  }
}
double
SbpNode
::
GreedyStrategy
()
{
// Current Cost, Minimum Cost, Cost with original sbp
double
curr_cost
=
0
;
double
original_cost
=
EvalNbhCost
();
double
min_cost
=
original_cost
;
int32_t
min_sbp
=
final_sbp_sig_id_
;
for
(
int32_t
sbp
=
0
;
sbp
<
cost_
.
size
();
sbp
++
)
{
final_sbp_sig_id_
=
sbp
;
curr_cost
=
EvalNbhCost
();
if
(
curr_cost
<
min_cost
)
{
min_cost
=
curr_cost
;
min_sbp
=
sbp
;
}
}
final_sbp_sig_id_
=
min_sbp
;
return
min_cost
-
original_cost
;
}
double
SbpNode
::
EvalNbhCost
()
const
{
// Current Cost, Minimum Cost, Cost with original sbp
double
curr_cost
=
cost_
[
final_sbp_sig_id_
];
for
(
SbpEdge
*
this_edge
:
edges_in_
)
{
curr_cost
+=
this_edge
->
cost_
[
this_edge
->
start_node_
->
final_sbp_sig_id_
][
final_sbp_sig_id_
];
}
for
(
SbpEdge
*
this_edge
:
edges_out_
)
{
curr_cost
+=
this_edge
->
cost_
[
final_sbp_sig_id_
][
this_edge
->
end_node_
->
final_sbp_sig_id_
];
}
return
curr_cost
;
}
// Evaluate the cost between this node and nodes OUTSIDE the given neighborhood:
// this node's own cost plus the cost of every incident edge whose other endpoint is
// not a key of node_list_id2nbh_id. Edges inside the neighborhood are skipped.
double SbpNode::EvalOutNbhCost(
    const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const {
  // check if this node is in the node list
  CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl;
  // Cost with original sbp
  double curr_cost = cost_[final_sbp_sig_id_];
  for (SbpEdge* this_edge : edges_in_) {
    // if the start node is not in the neighborhood
    if (node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_)
        == node_list_id2nbh_id.end()) {
      curr_cost += this_edge->cost_[this_edge->start_node_->final_sbp_sig_id_][final_sbp_sig_id_];
    }
  }
  for (SbpEdge* this_edge : edges_out_) {
    // if the end node is not in the neighborhood
    if (node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_)
        == node_list_id2nbh_id.end()) {
      curr_cost += this_edge->cost_[final_sbp_sig_id_][this_edge->end_node_->final_sbp_sig_id_];
    }
  }
  return curr_cost;
}
// Compute the cost between this node and adjacent nodes with a lower order
// Accumulate the edge cost between this node and neighborhood members with a LOWER
// order than this node (so each intra-neighborhood edge is counted exactly once when
// every member calls this). Short-circuits to GetMaxVal<float>() as soon as the
// accumulated cost exceeds the valid copy-cost threshold.
double SbpNode::EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                              const std::vector<int32_t>& nbh_id2order) const {
  // check if this node is in the node list
  CHECK(node_list_id_ >= 0) << "Compute in cost for a node out of the node list";
  // check if the node is in the neighborhood
  const auto& this_it = node_list_id2nbh_id.find(node_list_id_);
  CHECK(this_it != node_list_id2nbh_id.end())
      << "Compute in cost for a node out of the neighborhood";
  // Compute the minimum cost between this node and adjacent nodes with a lower order
  int32_t order = nbh_id2order[this_it->second];
  double curr_cost = 0;
  for (SbpEdge* this_edge : edges_in_) {
    const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);
    // if the start node is in the neighborhood
    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {
      curr_cost += this_edge->cost_[this_edge->start_node_->final_sbp_sig_id_][final_sbp_sig_id_];
      // End this function and return infinity.
      if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }
    }
  }
  for (SbpEdge* this_edge : edges_out_) {
    const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);
    // if the end node is in the neighborhood
    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {
      curr_cost += this_edge->cost_[final_sbp_sig_id_][this_edge->end_node_->final_sbp_sig_id_];
      if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }
    }
  }
  return curr_cost;
}
// Accumulate, for each edge toward a neighborhood member with a HIGHER order than this
// node, the minimum entry of that edge's cost matrix. This is a lower bound used for
// pruning in the brute-force neighborhood search (see SbpGraph::DfsAddNbhCost).
double SbpNode::EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                                 const std::vector<int32_t>& nbh_id2order) const {
  // check if this node is in the node list
  CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl;
  // check if the node is in the neighborhood
  const auto& this_it = node_list_id2nbh_id.find(node_list_id_);
  CHECK(this_it != node_list_id2nbh_id.end())
      << "Compute out cost for a node out of the neighborhood" << std::endl;
  // Compute the minimum cost between this node and adjacent nodes with a higher order
  int32_t order = nbh_id2order[this_it->second];
  double curr_cost = 0;
  for (SbpEdge* this_edge : edges_in_) {
    const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);
    // if the start node is in the neighborhood
    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {
      curr_cost += this_edge->GetMinCost();
    }
  }
  for (SbpEdge* this_edge : edges_out_) {
    const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);
    // if the end node is in the neighborhood
    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {
      curr_cost += this_edge->GetMinCost();
    }
  }
  return curr_cost;
}
// Collect the one-ring neighborhood of this node into nbh_1ring: the node itself first,
// then the node_list_id_ of every producer (via edges_in_) and consumer (via edges_out_).
void SbpNode::OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const {
  // One slot for this node plus one per incident edge.
  nbh_1ring.resize(edges_in_.size() + edges_out_.size() + 1);
  int32_t slot = 0;
  nbh_1ring[slot++] = node_list_id_;
  for (SbpEdge* in_edge : edges_in_) { nbh_1ring[slot++] = in_edge->start_node_->node_list_id_; }
  for (SbpEdge* out_edge : edges_out_) { nbh_1ring[slot++] = out_edge->end_node_->node_list_id_; }
}
// Get the n ring neighborhood of this node
// Pre-allocate buffer, which will be faster.
// Breadth-first expansion of the neighborhood: start from this node and apply one-ring
// expansion n times. nbh_1ring is a scratch buffer; node_tags is a visited-marker buffer
// indexed by node_list_id_ and is restored to all-false before returning.
void SbpNode::NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,
                                std::vector<int32_t>& nbh_1ring,
                                const std::vector<SbpNode*>& node_list,
                                std::vector<bool>& node_tags) const {
  // Initialize 0 ring
  if (n <= 0) { n = 0; }
  nbh_n_ring.resize(1);
  nbh_n_ring[0] = node_list_id_;
  node_tags[node_list_id_] = true;
  // l is the frontier cursor; it persists across rings so each node is expanded once.
  int32_t l = 0;
  // do ring expansion for n times
  for (int32_t i = 0; i < n; i++) {
    // r snapshots the ring boundary; nodes appended below belong to the next ring.
    for (int32_t r = nbh_n_ring.size(); l < r; l++) {
      node_list[nbh_n_ring[l]]->OneRingNeighborhood(nbh_1ring);
      for (auto nbh_id : nbh_1ring) {
        if (!node_tags[nbh_id]) {
          nbh_n_ring.push_back(nbh_id);
          node_tags[nbh_id] = true;
        }
      }
    }
  }
  // Recover false for buffer
  for (auto nbh_id : nbh_n_ring) { node_tags[nbh_id] = false; }
}
// Get or compute the minimum layer of this node
// Get or compute (memoized in min_layer_) the minimum topological layer of this node:
// one more than the largest minimum layer among its producers, including producers
// reachable through ctrl-in edges and mutable-op control dependencies.
// Proxy nodes (no op_node_) keep their initial negative layer.
int32_t SbpNode::GetMinLayer(
    const HashMap<std::string, SbpNode*>& op_name2sbp_node,
    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {
  // Already computed.
  if (min_layer_ >= 0) { return min_layer_; }
  // Skip sbp proxies, which have no underlying op.
  if (!op_node_) { return min_layer_; }
  // Data producers.
  for (SbpEdge* this_edge : edges_in_) {
    int32_t producer_min_layer =
        this_edge->start_node_->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
    if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
  }
  // Control-edge producers.
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end()) {
      int32_t producer_min_layer =
          it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
      if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
    }
  }
  // Mutable-op control dependencies.
  if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {
    for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {
      const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
      if (it != op_name2sbp_node.end()) {
        int32_t producer_min_layer =
            it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
        if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
      }
    }
  }
  // This node sits one layer past its latest producer (a source goes from -1 to 0).
  return ++min_layer_;
}
// Spread the minimum layer to compute the maximum layer of producers
// Spread the minimum layer to compute the maximum layer of producers: every producer
// (data, ctrl-in, or mutable-op ctrl dep) must finish no later than min_layer_ - 1,
// so drop each producer's max_layer_ to that bound.
void SbpNode::SpreadMaxLayer(
    const HashMap<std::string, SbpNode*>& op_name2sbp_node,
    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {
  // Skip sources and proxy nodes.
  if (min_layer_ <= 0) { return; }
  int32_t producer_max_lay = min_layer_ - 1;
  for (SbpEdge* this_edge : edges_in_) {
    this_edge->start_node_->DropMaxLayer(producer_max_lay);
  }
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }
  }
  if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {
    for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {
      const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
      if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }
    }
  }
}
// Drop down the maximum layer with the minimum layer from consumer
// Lower max_layer_ to upper_bound if it is unset (negative) or above the bound.
void SbpNode::DropMaxLayer(int32_t upper_bound) {
  if (max_layer_ < 0 || max_layer_ > upper_bound) { max_layer_ = upper_bound; }
}
// Set max_layer_ = min_layer_ if this node does not have any consumer
// This is the end of the whole graph
// We could also set it to be the maximum of the min_layer_ in the graph. (It should be the same.)
// For a node with no consumer (max_layer_ never dropped), lift max_layer_ up to
// min_layer_ so the layer interval is well-formed.
void SbpNode::LiftMaxLayer() {
  if (min_layer_ > max_layer_) { max_layer_ = min_layer_; }
}
// Set max_layer_ = upper_bound if this node does not have any consumer
// For a node with no consumer (max_layer_ still below min_layer_), lift max_layer_
// all the way to the supplied upper_bound.
void SbpNode::LiftMaxLayer(int32_t upper_bound) {
  if (min_layer_ > max_layer_) { max_layer_ = upper_bound; }
}
// Get the minimum element in Cost
// Get the minimum element of the per-signature cost vector.
// cost_ must be non-empty (i.e. InitializeSbp/merging already ran).
double SbpNode::GetMinCost() const {
  // Check the size of Cost
  CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl;
  // Compute the min_comp_cost
  return *std::min_element(cost_.begin(), cost_.end());
}
// Set the cut ratio
// Product of the cut ratios of all incident edges (incoming and outgoing).
double SbpNode::GetCutRatio() const {
  double accumulated_ratio = 1.0;
  for (auto* in_edge : edges_in_) { accumulated_ratio *= in_edge->GetCutRatio(); }
  for (auto* out_edge : edges_out_) { accumulated_ratio *= out_edge->GetCutRatio(); }
  return accumulated_ratio;
}
// Judge if this node is on the trunk
// If so, judge it for its producer/upstream nodes
// Mark this node as on the trunk, then recursively mark every producer (data or ctrl-in)
// whose min_layer_ is at least (this min_layer_ - 1) — i.e. producers that sit directly
// under this node in the layering.
void SbpNode::SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
  // Skip it if this node is already judged.
  if (on_trunk_) { return; }
  // Skip sbp proxy. This is before we have proxy.
  if (min_layer_ < 0) { return; }
  on_trunk_ = true;
  // If I am in the trunk, then all the children with (min_layer_ >= my layer id - 1) would be
  // considered as in the trunk
  for (SbpEdge* this_edge : edges_in_) {
    if (this_edge->start_node_->min_layer_ >= min_layer_ - 1) {
      this_edge->start_node_->SpreadTrunk(op_name2sbp_node);
    }
  }
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end() && it->second->min_layer_ >= min_layer_ - 1) {
      it->second->SpreadTrunk(op_name2sbp_node);
    }
  }
}
// Count consumers and any downstream nodes defined by control edges
// Increment counter_ on every producer (data or ctrl-in) of this node, counting how many
// consumers each producer has. counter_ must be cleared before the pass starts.
void SbpNode::RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
  // Should clear it before running.
  // skip the proxy nodes and the sources
  if (min_layer_ <= 0) { return; }
  for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->counter_++; }
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end()) { it->second->counter_++; }
  }
}
// Compute the minimal available wait time for producers or upstream nodes
// Compute, top-down from consumers to producers, how much trunk computation cost is
// available to overlap each incoming edge's wait time, assigning each edge's wait_time_
// and propagating the leftover budget (acc_trunk_cost_) to producers. A node is processed
// only after all of its consumers have decremented its counter_ to zero; counter_ is then
// set to -1 so the node is not visited again.
void SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost,
                                  const std::vector<double>& acc_trunk_cost,
                                  const HashMap<std::string, SbpNode*>& op_name2sbp_node,
                                  double wait_time) {
  // skip the proxy nodes and the sources
  if (min_layer_ <= 0) { return; }
  // Have not finished spreading for consumers or downstream nodes or already visited.
  if (counter_) { return; }
  if (on_trunk_) {
    // Nodes on the trunk does not have any accumulate cost
    acc_trunk_cost_ = 0;
  } else {
    if (acc_trunk_cost_ < 0) {
      // Do not have any consumer or downstream node
      acc_trunk_cost_ = acc_trunk_cost[min_layer_ - 1];
    } else {
      // Add the trunk cost at this layer
      acc_trunk_cost_ += trunk_cost[min_layer_];
    }
  }
  // Reduce the wait time for edges_in_, put the rest of the trunk cost in the producers
  for (SbpEdge* this_edge : edges_in_) {
    CHECK(this_edge->wait_time_ < 0)
        << "Double assign values into wait_time_ of this edge!" << std::endl;
    SbpNode* producer = this_edge->start_node_;
    // Accumulate the cost from the start node to this node
    double curr_trunk_cost =
        acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];
    if (curr_trunk_cost >= wait_time) {
      // Remain cost in the trunk is able to cover all the wait time
      this_edge->wait_time_ = 0.0;
      curr_trunk_cost -= wait_time;
    } else {
      // Remain cost in the trunk can only cover partial wait time
      this_edge->wait_time_ = wait_time - curr_trunk_cost;
      curr_trunk_cost = 0.0;
    }
    // Reducing non-matching edges
    // For example:
    // (1) P->S0->S0->S0->B
    // (2) p->B->B->B->B
    // We would use (2) when the tensor is relatively tiny.
    // Do not inherit trunk cost for nodes on the trunk
    if (!producer->on_trunk_) {
      // Inherit the minimal of the trunk cost from consumers
      producer->DropAvailWaitTime(curr_trunk_cost);
    }
    producer->counter_--;
    producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time);
  }
  // Put the rest the trunk cost in the upstream nodes.
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end()) {
      SbpNode* producer = it->second;
      // Do not inherit trunk cost for nodes on the trunk
      if (!producer->on_trunk_) {
        // Accumulate the cost from the start node to this node
        double curr_trunk_cost =
            acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];
        // Inherit the minimal of the trunk cost from consumers
        producer->DropAvailWaitTime(curr_trunk_cost);
      }
      producer->counter_--;
      producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time);
    }
  }
  // Set counter_ to be -1, do not visit it again.
  counter_--;
}
// Drop down the available wait time with the minimum cost from downstream
// Keep the minimum trunk-cost budget observed from any downstream consumer.
// A negative acc_trunk_cost_ means "not assigned yet".
void SbpNode::DropAvailWaitTime(double curr_trunk_cost) {
  const bool unassigned = acc_trunk_cost_ < 0.0;
  if (unassigned || curr_trunk_cost < acc_trunk_cost_) { acc_trunk_cost_ = curr_trunk_cost; }
}
// Assemble copy cost for all the incoming edges
// Assemble copy cost for all the incoming edges of this node: for each incoming edge,
// initialize its cost from every input blob produced by that edge's producer, then add
// the edge's wait_time_ to every strictly positive entry (transfers between devices).
void SbpNode::InitializeCopyCost(bool use_sbp_collector) {
  for (SbpEdge* this_edge : edges_in_) {
    const auto* sbp_node_producer = this_edge->start_node_;
    OpNode* producer = sbp_node_producer->op_node_;
    // skip it if proxy (proxy nodes have no op_node_)
    if (use_sbp_collector && !producer) { continue; }
    // look through input blobs
    for (const std::string& ibn : op_node_->op().input_bns()) {
      // Only the blobs actually produced by this edge's start node contribute.
      if (producer->op().op_name() == op_node_->SrcNode4Ibn(ibn).op().op_name()) {
        this_edge->InitializeCopyCost(ibn, use_sbp_collector);
      }
    }
    // Add Wait time
    for (auto& cost_row : this_edge->cost_) {
      for (auto& cost_value : cost_row) {
        // If transferring between devices, we need to add wait time.
        if (cost_value > 0.0) { cost_value += this_edge->wait_time_; }
      }
    }
  }
}
// Reduce and set the wait time for op in the trunk
// For a trunk node, cap the wait time of every outgoing edge at trunk_wait_time;
// edges with an unassigned (negative) wait time also receive it. No-op off the trunk.
void SbpNode::SetTrunkWaitTime(double trunk_wait_time) {
  // only reduce the wait time for operators in the trunk
  if (!on_trunk_) { return; }
  // Reduce the wait time for edges_out_
  for (SbpEdge* edge_out : edges_out_) {
    const bool unassigned = edge_out->wait_time_ < 0.0;
    if (unassigned || edge_out->wait_time_ > trunk_wait_time) {
      edge_out->wait_time_ = trunk_wait_time;
    }
  }
  // Might reduce it for edges_in_
}
// Drop down the maximum layer with the minimum layer from consumer
// Lower tributary_layer_ to upper_bound if it is unset (negative) or above the bound.
void SbpNode::DropTributaryLayer(int32_t upper_bound) {
  if (tributary_layer_ < 0 || tributary_layer_ > upper_bound) { tributary_layer_ = upper_bound; }
}
// Compute maximum layer for tributaries
// Compute the maximum (latest) layer each producer may run at without delaying the trunk:
// trunk nodes bound their producers to min_layer_ - 1, tributary nodes pass on their own
// tributary_layer_. Recurses into a producer only once all of its consumers have
// decremented its counter_ to zero.
void SbpNode::SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
  // Wait for the remaining consumers; skip proxies and sources.
  if (counter_ || min_layer_ <= 0) { return; }
  int32_t producer_max_lay = 0;
  if (on_trunk_) {
    producer_max_lay = min_layer_ - 1;
  } else {
    // On a tributary, the operator could be run later.
    producer_max_lay = tributary_layer_;
    // producer_max_lay = tributary_layer_ - 1;
  }
  for (SbpEdge* this_edge : edges_in_) {
    this_edge->start_node_->DropTributaryLayer(producer_max_lay);
    if (--this_edge->start_node_->counter_ == 0) {
      this_edge->start_node_->SpreadTributaryLayer(op_name2sbp_node);
    }
  }
  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
    if (it != op_name2sbp_node.end()) {
      it->second->DropTributaryLayer(producer_max_lay);
      if (--it->second->counter_ == 0) { it->second->SpreadTributaryLayer(op_name2sbp_node); }
    }
  }
  // Mark this node as visited (counter_ becomes -1).
  counter_--;
}
// Return the edge connecting this node with other_node (in either direction),
// or nullptr when the two nodes are not adjacent.
SbpEdge* SbpNode::FindEdgeWithNode(const SbpNode* other_node) const {
  for (auto* candidate : edges_in_) {
    if (candidate->start_node_ == other_node) { return candidate; }
  }
  for (auto* candidate : edges_out_) {
    if (candidate->end_node_ == other_node) { return candidate; }
  }
  return nullptr;
};
// Decide to use this SbpSignature
// Decide to use this SbpSignature: return the signature selected by final_sbp_sig_id_.
// The signature list must be non-empty.
const NdSbpSignature& SbpNode::FinalSbpSignature() const {
  CHECK(!sbp_sig_list_.empty()) << "Asking for sbp signature for an empty node";
  return sbp_sig_list_[final_sbp_sig_id_];
};
}
// namespace auto_parallel
}
// namespace oneflow
oneflow/core/auto_parallel/sbp_node.h
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
namespace
oneflow
{
namespace
auto_parallel
{
class
SbpEdge
;
// A node structure to deal with the SBP strategy.
// Please see SbpGraph for the whole algorithm and introduction.
class
SbpNode
final
{
public:
// default constructor
SbpNode
()
:
final_sbp_sig_id_
(
0
)
{}
// This constructor is to merge two node into one
SbpNode
(
SbpNode
*
first
,
SbpNode
*
second
);
~
SbpNode
();
OF_DISALLOW_COPY_AND_MOVE
(
SbpNode
);
bool
operator
==
(
const
SbpNode
&
other
)
{
return
this
==
&
other
;
}
// another node point to this node
void
PointFrom
(
SbpNode
*
start_node
);
// this node point to another node
void
PointTo
(
SbpNode
*
end_node
);
SbpEdge
*
FindEdgeWithNode
(
const
SbpNode
*
other_node
)
const
;
// Check and eliminate one child node.
// Only used by SbpGraph since it need to remove it from the NodeList after this.
bool
EliminateItselfAsChild
();
// Initialize SbpSignature from Signature Objects
void
InitializeSbp
();
// Decide to use this SbpSignature
const
NdSbpSignature
&
FinalSbpSignature
()
const
;
// Recompute Computation Cost after adding child nodes in it
void
SummarizeCost
();
// Determine Final SbpSignature for attachment of this node
void
FinalizeSbp
();
// Use a greedy strategy to pick the sbp signature with minimum cost for this
// node. You should have an initial strategy before running this.
double GreedyStrategy();
// Evaluate the summary of cost between this neighborhood and outside nodes.
double EvalOutNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const;
// Evaluate the summary of cost within the neighborhood.
// We only accumulate the edge cost with a lower order.
double EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                     const std::vector<int32_t>& nbh_id2order) const;
// Evaluate the summary of cost within the neighborhood.
// We only accumulate the minimum edge cost with a higher order.
double EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
                        const std::vector<int32_t>& nbh_id2order) const;
// Get the one-ring neighborhood of this node, which is itself and all the
// adjacent nodes.
void OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const;
// Get the n-ring neighborhood of this node.
// Pre-allocate the buffers (nbh_n_ring, nbh_1ring, node_tags), which will be
// faster.
void NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,
                       std::vector<int32_t>& nbh_1ring, const std::vector<SbpNode*>& node_list,
                       std::vector<bool>& node_tags) const;
// Get or compute the minimum layer of this node.
int32_t GetMinLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node,
                    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);
// Spread the minimum layer to compute the maximum layer of producers.
void SpreadMaxLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node,
                    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);
// Set max_layer_ = min_layer_ if this node does not have any consumer.
void LiftMaxLayer();
// Set max_layer_ = upper_bound if this node does not have any consumer.
void LiftMaxLayer(int32_t upper_bound);
// Compute the maximum layer for tributaries (branches off the trunk).
void SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Drop down the tributary layer to upper_bound.
void DropTributaryLayer(int32_t upper_bound);
// Get the minimum element in Cost (cost_).
double GetMinCost() const;
// Get the cut ratio.
double GetCutRatio() const;
// Judge if this node is on the trunk.
// If so, judge it for its producer/upstream nodes as well.
void SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Count consumers and any downstream nodes defined by control edges
// for producers or upstream nodes.
void RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Compute the minimal available wait time for producers or upstream nodes.
void SpreadAvailWaitTime(const std::vector<double>& trunk_cost,
                         const std::vector<double>& acc_trunk_cost,
                         const HashMap<std::string, SbpNode*>& op_name2sbp_node, double wait_time);
// Reduce and set the wait time for an op in the trunk.
void SetTrunkWaitTime(double trunk_wait_time);
// Assemble the copy cost for all the incoming edges.
void InitializeCopyCost(bool use_sbp_collector);
private:
friend class SbpEdge;
friend class SbpGraph;
friend class SbpCollector;
friend class SbpConstructor;
// Compound edges in (producers).
std::vector<SbpEdge*> edges_in_;
// Compound edges out (consumers).
std::vector<SbpEdge*> edges_out_;
// Location in node_list of SbpGraph; -1 means not registered yet.
int32_t node_list_id_ = -1;
// Global SbpSignature list size; -1 means uninitialized.
int32_t global_sbp_sig_size_ = -1;
// Decide to use the SbpSignature with this id (index into sbp_sig_list_).
int32_t final_sbp_sig_id_;
// Available SbpSignature objects for this node.
std::vector<NdSbpSignature> sbp_sig_list_;
// cost_[sbp] is the computation cost when using sbp_sig_list_[sbp].
std::vector<double> cost_;
// Child node list.
std::vector<SbpNode*> children_;
// SbpSignature for each child node when using a specific SbpSignature for
// this node. Its dimension is (number of child nodes) x (number of available
// SbpSignatures for this node).
std::vector<std::vector<int32_t>> child_node_sbp_sig_;
// The two nodes merged into this compound node.
std::vector<SbpNode*> half_node_;
// We should delete those merged signatures which have a very large cost, for
// speed up. New sbp_sig_list_ index maps to each half_node_'s sig_index.
std::vector<std::pair<int32_t, int32_t>> merged_sig_id2children_sig_id_;
std::vector<BinarySet> parallel_candidates_;
// The op this node represents; nullptr for proxy/compound nodes — confirm in
// the constructors.
OpNode* op_node_ = nullptr;
// We divide the sbp graph into multiple layers.
// min_layer_ is the minimum layer number to run this op as soon as possible.
// max_layer_ is the maximum layer number without slowing down the whole
// process of the graph.
// producer.max_layer_ < this_node.min_layer_ <= this_node.max_layer_ < consumer.min_layer_
int32_t min_layer_ = -1, max_layer_ = -1;
// Maximum layer in tributaries.
int32_t tributary_layer_ = -1;
// Whether we are on the trunk.
bool on_trunk_ = false;
// A counter buffer for topological traversal or similar passes.
int32_t counter_ = 0;
// Accumulated trunk cost from consumer to the end.
double acc_trunk_cost_ = -1.0;
// Let one node point to another (shared implementation of PointFrom/PointTo).
void StartPointToEnd(SbpNode* start_node, SbpNode* end_node);
// Evaluate the summary of cost in the 1-ring neighborhood.
double EvalNbhCost() const;
// Drop down the maximum layer with the minimum layer from the consumer.
void DropMaxLayer(int32_t upper_bound);
// Drop down the available wait time with the minimum cost from downstream.
void DropAvailWaitTime(double curr_trunk_cost);
};
// class SbpNode
}
// namespace auto_parallel
}
// namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
oneflow/core/auto_parallel/sbp_util.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
namespace
oneflow
{
namespace
auto_parallel
{
// Judge whether we need the same SBP for both producer and consumer
bool
RequireSameSbp
(
const
OpNode
*
consumer
,
const
std
::
string
&
ibn
)
{
// is mutable
const
auto
&
input_blob_modifier_
=
consumer
->
op
().
InputBlobModifier4Ibn
(
ibn
);
if
(
input_blob_modifier_
.
has_is_mutable
()
&&
input_blob_modifier_
.
is_mutable
())
{
return
true
;
}
// kOFRecord or kTensorBuffer don't accept boxing
const
LogicalBlobId
&
lbi
=
consumer
->
op
().
BnInOp2Lbi
(
ibn
);
const
OpNode
&
producer
=
consumer
->
ProducerOpNode4Lbi
(
lbi
);
const
BlobDesc
&
logical_blob_desc
=
producer
.
LogicalBlobDesc4Lbi
(
lbi
);
return
(
logical_blob_desc
.
data_type
()
==
DataType
::
kOFRecord
||
logical_blob_desc
.
data_type
()
==
DataType
::
kTensorBuffer
);
}
}
// namespace auto_parallel
}
// namespace oneflow
Prev
1
…
4
5
6
7
8
9
10
11
12
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment