"vscode:/vscode.git/clone" did not exist on "eb3ec5082ace9154438058f9fa4077c50960e0ba"
Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
@@ -90,8 +90,8 @@ ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
  m.attr("max_split_axis") = kMaxSplitAxis;
  py::class_<Symbol<SbpParallel>, std::shared_ptr<Symbol<SbpParallel>>>(m, "sbp",
                                                                        py::dynamic_attr())
-      .def("__str__", &api::SbpToString)
-      .def("__repr__", &api::SbpToString)
+      .def("__str__", &api::ApiSbpToString)
+      .def("__repr__", &api::ApiSbpToString)
      .def(py::self == py::self)
      .def(py::hash(py::self))
      .def("_ToAttrStr",
...
@@ -15,13 +15,14 @@ limitations under the License.
*/
#include "oneflow/api/python/utils/tensor_utils.h"
-#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
+#include "oneflow/core/job/global_mode.h"
+#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/consistency_check.h"
@@ -32,11 +33,11 @@ namespace py = pybind11;
namespace oneflow {
namespace one {

-Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
+Maybe<void> EagerLocalTensorZeros(const std::shared_ptr<Tensor>& t) {
  JUST(functional::CheckInplaceValid(t));
-  std::shared_ptr<MirroredTensor> local_tensor;
+  std::shared_ptr<LocalTensor> local_tensor;
  if (t->is_local()) {
-    local_tensor = JUST(t->AsMirroredTensor());
+    local_tensor = JUST(t->AsLocalTensor());
  } else {
    local_tensor = JUST(t->cur_rank_phy_tensor());
  }
@@ -44,9 +45,9 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    JUST(builder->AccessBlobByCallback(
        local_tensor,
-        [](uint64_t of_blob_ptr) {
-          auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
-          of_blob->AsyncAutoMemset(0);
+        [](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
+          AutoMemset(stream, eager_blob_object->mut_dptr(), 0,
+                     eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case());
        },
        "mut"));
    return Maybe<void>::Ok();
@@ -54,38 +55,25 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
  return Maybe<void>::Ok();
}

-template<typename T>
-Maybe<void> CopyMirroredTensorFromUntypedArray(const std::shared_ptr<Tensor>& tensor,
-                                               PyObject* array) {
-  return CopyBetweenMirroredTensorAndNumpy<T>(tensor, array, BlobNumpyCopyUtil<T>::From, "mut",
-                                              /*block_host_until_done=*/false);
-}
-
-Maybe<std::string> GetCopyMirroredTensorToNumpyFuncName(DataType dtype) {
-  using namespace oneflow;
-  static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
-#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
-  {type_proto, std::make_shared<std::string>("_copy_to_numpy_" #type_cpp)},
-      OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
-#undef DATA_TYPE_FUNC_NAME_PAIR
-  };
-  return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
-}
-
-Maybe<std::string> GetCopyMirroredTensorFromNumpyFuncName(DataType dtype) {
-  using namespace oneflow;
-  static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
-#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
-  {type_proto, std::make_shared<std::string>("_copy_from_numpy_" #type_cpp)},
-      OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
-#undef DATA_TYPE_FUNC_NAME_PAIR
-  };
-  return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
-}
+namespace {
+void CopyFromNumpyArray(ep::Stream* stream,
+                        const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
+                        const NumPyArrayPtr& array_ptr) {
+  SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(),
+                 eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(),
+                 memory::MakeHostMemCase());
+}
+}  // namespace
+
+Maybe<void> CopyLocalTensorFromUntypedArray(const std::shared_ptr<Tensor>& tensor,
+                                            PyObject* array) {
+  return CopyBetweenLocalTensorAndNumpy(tensor, array, CopyFromNumpyArray, "mut",
+                                        /*block_host_until_done=*/false);
+}

Maybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
-  const auto& tensor = JUST(t->AsMirroredTensor());
+  const auto& tensor = JUST(t->AsLocalTensor());
  if (tensor->dtype() != DType::TensorBuffer()) {
    return Error::RuntimeError() << "tensor buffer supported only";
  }
@@ -93,10 +81,11 @@ MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
  std::vector<Shape> shapes;
  std::vector<Symbol<DType>> dtypes;
-  auto btb = std::make_shared<BlockingThenBusy>(1);
+  auto btb = std::make_shared<BlockingThenBusy>();
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->SyncAccessBlobByCallback(
-        tensor, btb, [](uint64_t) {}, "const");
+        tensor, btb, [](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>&) {},
+        "const");
  }));
  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
@@ -136,41 +125,51 @@ Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
  return tuple;
}

-#define MAKE_SWITCH_ENTRY(func_name, dtype) func_name<dtype>
-DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, CopyMirroredTensorFromUntypedArray, MAKE_SWITCH_ENTRY,
-                          MAKE_DATA_TYPE_CTRV_SEQ(POD_AND_HALF_DATA_TYPE_SEQ));
-
Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
                                      const Optional<Symbol<Device>>& device,
                                      const bool requires_grad, const bool pin_memory) {
-  PyObject* array = NULL;
+  bool is_bfloat16_dtype = dtype ? JUST(dtype)->data_type() == DataType::kBFloat16 : false;
+  bool is_cuda_device = device ? JUST(device)->enum_type() == DeviceType::kCUDA : false;
+  if (is_bfloat16_dtype && is_cuda_device) {
+#if (CUDA_VERSION < 11000)
+    return Error::RuntimeError()
+           << "Cannot create a bfloat16 tensor on gpu under cuda version: 11000";
+#endif  // CUDA_VERSION >= 11000
+#ifdef WITH_ROCM
+    return Error::RuntimeError()
+           << "Cannot create a bfloat16 tensor on gpu under ROCm for now";
+#endif  // WITH_ROCM
+  }
  PyArray_Descr* np_dtype =
-      dtype.has_value()
+      dtype.has_value() && !is_bfloat16_dtype
          ? PyArray_DescrFromType(JUST(numpy::OFDataTypeToNumpyType(JUST(dtype)->data_type())))
          : nullptr;
-  // PyArray_FromAny steals a reference to np_dtype object, so no need to decref it.
  // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
  // array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
  // NPY_ARRAY_FORCECAST is needed otherwise there will be a segfault.
-  array = PyArray_FromAny(data, np_dtype, 0, 0,
-                          NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr);
-  if (!array) {
+  //
+  // Even though PyArray_FromAny can cast the input array to the desired dtype
+  // if `dtype` argument is set, it fails to handle the following case:
+  // >> x = [flow.tensor([1, 2])] * 3  <-- x is a list of flow.Tensor
+  // >> y = flow.tensor(x, dtype=flow.float32)  <-- returns nullptr
+  // However, the following case without `dtype` argument works well:
+  // >> x = [flow.tensor([1, 2])] * 3
+  // >> y = flow.tensor(x)
+  // So we cast the input array to the desired dtype manually.
+  PyArrayObject* _array = reinterpret_cast<PyArrayObject*>(
+      PyArray_FromAny(data, nullptr, 0, 0,
+                      NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr));
+  if (!_array) {
    return Error::RuntimeError() << "Can not convert input data to a new numpy array.";
  }
-  // flow.tensor([1., 2.]).dtype should be flow.float32 rather than flow.float64
-  if (!PyArray_Check(data)) {
-    int np_array_type = PyArray_TYPE(reinterpret_cast<PyArrayObject*>(array));
-    // Cast to float if data is double sequence, rather than numpy array.
-    if (np_array_type == NPY_DOUBLE && np_dtype == nullptr) {
-      PyObject* fp32_array = PyArray_Cast(reinterpret_cast<PyArrayObject*>(array), NPY_FLOAT);
-      Py_DECREF(array);
-      array = fp32_array;
-    }
-  }
+  // PyArray_FromArray steals a reference to np_dtype object, so no need to decref it.
+  PyObject* array = PyArray_FromArray(
+      _array, np_dtype, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST);
+  Py_DECREF(_array);
  auto* np_arr = reinterpret_cast<PyArrayObject*>(array);
  const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
  const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
-  DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));
+  DataType np_data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));

  Symbol<Device> device_;
  if (device) {
@@ -179,10 +178,17 @@ Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DTyp
    device_ = JUST(Device::New("cpu"));
  }
  std::shared_ptr<Tensor> tensor = JUST(
-      functional::Empty(shape, JUST(DType::Get(data_type)), device_, /*pin_memory=*/pin_memory));
-  JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), tensor, array));
+      functional::Empty(shape, JUST(DType::Get(np_data_type)), device_, /*pin_memory=*/pin_memory));
+  JUST(CopyLocalTensorFromUntypedArray(tensor, array));
  Py_DECREF(array);

+  if (dtype && JUST(dtype)->data_type() != np_data_type) {
+    tensor = JUST(functional::To(tensor, JUST(dtype), false));
+  } else if (!dtype && !PyArray_Check(data) && tensor->dtype()->is_floating_point()
+             && GetDefaultDType() != tensor->dtype()) {
+    // If no dtype was assigned and the tensor was created from a PySequence,
+    // cast it to the default floating dtype.
+    tensor = JUST(functional::To(tensor, JUST(DType::Get(DataType::kFloat)), false));
+  }
  JUST(tensor->set_requires_grad(requires_grad));
  return tensor;
}
@@ -201,10 +207,10 @@ auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);

}  // namespace

-Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
-                                           Symbol<ParallelDesc> placement,
-                                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,
-                                           const bool requires_grad) {
+Maybe<Tensor> MakeGlobalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
+                                       Symbol<ParallelDesc> placement,
+                                       const std::vector<Symbol<SbpParallel>>& sbp_tuple,
+                                       const bool requires_grad) {
  PyObject* array = NULL;
  if (PyArray_Check(data)) {
    // Only NPY_CORDER is supported, and returns a new C-style contiguous array.
@@ -229,9 +235,13 @@ Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol
  }

  Symbol<Device> device = JUST(Device::New(placement->device_tag()));
-  std::shared_ptr<Tensor> local_tensor =
-      JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*pin_memory=*/false));
-  JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), local_tensor, array));
+  std::shared_ptr<Tensor> local_tensor;
+  {
+    GlobalMode::Guard guard(/* disable global mode */ false);
+    local_tensor =
+        JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*pin_memory=*/false));
+  }
+  JUST(CopyLocalTensorFromUntypedArray(local_tensor, array));
  Py_DECREF(array);

  // Cast to float if data is double sequence, rather than numpy array.
@@ -246,14 +256,16 @@ Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol
  size_t sbp_dims = sbp_tuple.size();
  Symbol<NdSbp> broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(sbp_dims));

-  std::shared_ptr<Tensor> broadcast_tensor = JUST(functional::LocalToConsistent(
-      local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape, local_tensor->dtype()));
+  std::shared_ptr<Tensor> broadcast_tensor = JUST(
+      functional::LocalToGlobal(local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape,
+                                local_tensor->dtype(), /* sync_data */ true, /*copy=*/false));

  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
-  auto consistent_tensor = JUST(functional::ToConsistent(broadcast_tensor, placement, sbp_tuple,
-                                                         grad_sbp_tuple, /* check_meta */ false));
-  JUST(consistent_tensor->set_requires_grad(requires_grad));
-  return consistent_tensor;
+  auto global_tensor =
+      JUST(functional::ToGlobal(broadcast_tensor, placement, sbp_tuple, grad_sbp_tuple,
+                                /* check_meta */ false, /*copy=*/false));
+  JUST(global_tensor->set_requires_grad(requires_grad));
+  return global_tensor;
}

Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
@@ -265,9 +277,9 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
    const Symbol<NdSbp>& nd_sbp = JUST(other->nd_sbp());
    const std::vector<Symbol<SbpParallel>>& sbp_tuple = *JUST(GetSbpList(nd_sbp));
    std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
-    // TODO:(zhaoluyang) consistent case support pin_memory
-    return functional::ToConsistent(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple,
-                                    /* check_meta */ false);
+    // TODO:(zhaoluyang) global case support pin_memory
+    return functional::ToGlobal(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple,
+                                /* check_meta */ false, /*copy=*/false);
  }
}

@@ -283,7 +295,7 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
    tensor = JUST(functional::Copy(other, device_->type(), device_->device_id(),
                                   pin_memory && !dtype.has_value()));
  } else {
-    tensor = JUST(functional::ConsistentToLocal(other));
+    tensor = JUST(functional::GlobalToLocal(other, /*copy=*/false));
    if (!device) { device_ = JUST(Device::New("cpu")); }
    tensor = JUST(functional::Copy(tensor, device_->type(), device_->device_id(),
                                   pin_memory && !dtype.has_value()));
@@ -302,9 +314,9 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                        const std::vector<Symbol<SbpParallel>>& sbp_tuple,
                                        const bool requires_grad) {
  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
-  bool check_meta = other->is_consistent() ? false : true;
-  std::shared_ptr<Tensor> tensor =
-      JUST(functional::ToConsistent(other, placement, sbp_tuple, grad_sbp_tuple, check_meta));
+  bool check_meta = other->is_global() ? false : true;
+  std::shared_ptr<Tensor> tensor = JUST(functional::ToGlobal(
+      other, placement, sbp_tuple, grad_sbp_tuple, check_meta, /*copy=*/false));
  if (dtype) {
    const Symbol<DType>& dtype_ = JUST(dtype);
    if (tensor->dtype() != dtype_) {
...
@@ -29,10 +29,13 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/common/stride.h"
-#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/blocking_then_busy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/common/foreign_lock_helper.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/api/python/functional/common.h"
+#include "oneflow/core/framework/tensor_util.h"
+#include "oneflow/core/profiler/profiler.h"

namespace py = pybind11;
@@ -55,13 +58,13 @@ struct format_descriptor<oneflow::float16> {
namespace oneflow {
namespace one {

-Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t);
+Maybe<void> EagerLocalTensorZeros(const std::shared_ptr<Tensor>& t);

template<typename T>
-inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
+inline static Maybe<PyObject*> EagerLocalTensorToNumpy(PyObject* py_tensor) {
  const auto& t = PyTensor_Unpack(py_tensor);

-  std::shared_ptr<MirroredTensor> tensor = JUST(t->AsMirroredTensor());
+  std::shared_ptr<LocalTensor> tensor = JUST(t->AsLocalTensor());
  CHECK_OR_RETURN(JUST(tensor->device()) == JUST(Device::New("cpu")));
  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only.";
  // set base object attr
@@ -74,12 +77,13 @@ inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
      numpy::OFStrideToNumpyStride(*JUST(tensor->stride()), tensor->dtype()->data_type());

  T* data_ptr = nullptr;
-  const auto& Callback = [&](uint64_t ofblob_ptr) {
-    data_ptr = reinterpret_cast<OfBlob*>(ofblob_ptr)->mut_blob()->mut_dptr<T>();
+  const auto& Callback = [&](ep::Stream*,
+                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
+    data_ptr = eager_blob_object->mut_dptr<T>();
  };
-  auto btb = std::make_shared<BlockingThenBusy>(1);
+  auto btb = std::make_shared<BlockingThenBusy>();
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
-    return builder->SyncAccessBlobByCallback(tensor, btb, Callback, "mut");
+    return builder->SyncAccessBlobByCallback(tensor, btb, Callback, "const");
  }));
  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
  return py::array(py::buffer_info(data_ptr, sizeof(T), py::format_descriptor<T>::format(), ndim,
@@ -90,19 +94,43 @@ inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
}

template<typename T>
+struct TensorTypeToPyType final {
+  typedef T type;
+};
+
+template<>
+struct TensorTypeToPyType<float16> final {
+  typedef float type;
+};
+
+template<>
+struct TensorTypeToPyType<bfloat16> final {
+  typedef float type;
+};
+
+template<typename T>
+inline static Maybe<PyObject*> EagerLocalTensorItem(const std::shared_ptr<Tensor>& tensor) {
+  // OF_PROFILER_RANGE_GUARD("EagerLocalTensorItem");
+  T value = JUST(GetItemInScalarTensor<T>(tensor));
+  return functional::CastToPyObject(static_cast<typename TensorTypeToPyType<T>::type>(value));
+}
+
-inline Maybe<void> CopyBetweenMirroredTensorAndNumpy(
+inline Maybe<void> CopyBetweenLocalTensorAndNumpy(
    const std::shared_ptr<Tensor>& t, PyObject* array,
-    Maybe<void> (*Copy)(uint64_t, const NumPyArrayPtr&), const std::string& modifier,
-    bool block_host_until_done) {
-  auto tensor = JUST(t->AsMirroredTensor());
+    void (*Copy)(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&, const NumPyArrayPtr&),
+    const std::string& modifier, bool block_host_until_done) {
+  auto tensor = JUST(t->AsLocalTensor());
+  CHECK_OR_RETURN(tensor->is_contiguous()) << "contiguous tensors supported only.";
  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only.";

  if (block_host_until_done) {
    NumPyArrayPtr array_ptr(array);
-    const auto& Callback = [array_ptr, Copy](uint64_t ofblob_ptr) {
-      CHECK_JUST(Copy(ofblob_ptr, array_ptr));
+    const auto& Callback = [array_ptr, Copy](
+                               ep::Stream* stream,
+                               const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
+      Copy(stream, eager_blob_object, array_ptr);
    };
-    auto btb = std::make_shared<BlockingThenBusy>(1);
+    auto btb = std::make_shared<BlockingThenBusy>();
    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
      return builder->SyncAccessBlobByCallback(tensor, btb, Callback, modifier);
    }));
@@ -119,17 +147,16 @@ inline Maybe<void> CopyBetweenMirroredTensorAndNumpy(
    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
      return builder->AccessBlobByCallback(
          tensor,
-          [array_ptr, Copy](uint64_t ofblob_ptr) { CHECK_JUST(Copy(ofblob_ptr, array_ptr)); },
+          [array_ptr, Copy](ep::Stream* stream,
+                            const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
+            Copy(stream, eager_blob_object, array_ptr);
+          },
          modifier);
    }));
  }
  return Maybe<void>::Ok();
}

-Maybe<std::string> GetCopyMirroredTensorToNumpyFuncName(DataType dtype);
-
-Maybe<std::string> GetCopyMirroredTensorFromNumpyFuncName(DataType dtype);
-
Maybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t);
@@ -144,10 +171,10 @@ Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DTyp
                                      const Optional<Symbol<Device>>& device,
                                      const bool requires_grad, const bool pin_memory);

-Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
-                                           Symbol<ParallelDesc> placement,
-                                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,
-                                           const bool requires_grad);
+Maybe<Tensor> MakeGlobalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
+                                       Symbol<ParallelDesc> placement,
+                                       const std::vector<Symbol<SbpParallel>>& sbp_tuple,
+                                       const bool requires_grad);

Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                        const bool pin_memory);
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/algorithm_util.h"
namespace oneflow {
namespace auto_parallel {
// Inverse function of order.
// We need inverse_order (a.k.a. id2order) instead of id2value to eliminate ties.
// For example, given v[0] < v[1] = v[2] < v[3], comp(v[1], v[2]) cannot tell whether
// v[1] comes before or after v[2]. Transferring to orders gives
// order[0] < order[1] < order[2] < order[3], a strict total order.
void InverseOrder(const std::vector<int32_t>& order, std::vector<int32_t>& inverse_order) {
inverse_order.resize(order.size());
for (int32_t i = 0; i < order.size(); i++) { inverse_order[order[i]] = i; }
}
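// A minimal worked example (hypothetical values), assuming `order` came from DecideOrder:
//   order         = {2, 0, 3, 1}  // v[2] <= v[0] <= v[3] <= v[1]
//   inverse_order = {1, 3, 0, 2}  // element i ends up at rank inverse_order[i]
// so inverse_order[order[k]] == k for every k, and ties in v no longer matter.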
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
#include <vector>
#include <cstdlib>
#include <algorithm>
#include <unordered_map>
namespace oneflow {
namespace auto_parallel {
// This function removes the i-th element from a vector in constant time.
// It must only be used when the vector does not care about ordering.
// Use it with care: when removing while iterating, traverse the vector
// from back to front so that the indices not yet visited stay valid.
template<class T>
void RemoveFrom(std::vector<T>& v, int32_t i) {
v[i] = v.back();
v.pop_back();
}
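// A minimal usage sketch (hypothetical values): traversing back to front keeps
// unvisited indices stable while erasing in O(1) per removal:
//   std::vector<int32_t> v = {3, 1, 4, 1, 5};
//   for (int32_t i = v.size() - 1; i >= 0; i--) {
//     if (v[i] == 1) { RemoveFrom(v, i); }
//   }
//   // v now holds {3, 5, 4}; the order has changed, which callers must tolerate.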
template<class T>
void CheckAndRemoveFrom(std::vector<T>& v, T& t) {
for (int32_t i = v.size() - 1; i >= 0; i--) {
if (v[i] == t) {
RemoveFrom<T>(v, i);
break;
}
}
}
// Inverse function, which transfers a vector to an unordered_map (element -> index).
template<class T>
void InverseFunction(const std::vector<T>& v, std::unordered_map<T, int32_t>& inverse_map) {
inverse_map.clear();
for (int32_t i = 0; i < v.size(); i++) { inverse_map[v[i]] = i; }
}
// When you want to sort something but cannot move any elements, use order.
// DecideOrder computes the sorting order of a list v such that
// v[order[i]] < v[order[j]] for all i < j.
// With a user-defined comparison, the guarantee becomes
// comp(v[order[i]], v[order[j]]) == true for all i < j.
template<class T, class Compare>
void DecideOrder(const T& v, std::vector<int32_t>& order, const Compare& comp) {
// Initialize order
order.resize(v.size());
for (int32_t i = 0; i < v.size(); i++) { order[i] = i; }
// sort
std::sort(order.begin(), order.end(), [&](int32_t i, int32_t j) { return comp(v[i], v[j]); });
}
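// A minimal worked example (hypothetical values):
//   std::vector<double> v = {0.3, 0.1, 0.2};
//   std::vector<int32_t> order;
//   DecideOrder(v, order, std::less<double>());
//   // order == {1, 2, 0}, i.e. v[1] <= v[2] <= v[0], while v itself is never moved.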
// Inverse function of order.
// We need inverse_order (a.k.a. id2order) instead of id2value to eliminate ties.
// For example, given v[0] < v[1] = v[2] < v[3], comp(v[1], v[2]) cannot tell whether
// v[1] comes before or after v[2]. Transferring to orders gives
// order[0] < order[1] < order[2] < order[3], a strict total order.
void InverseOrder(const std::vector<int32_t>& order, std::vector<int32_t>& inverse_order);
} // namespace auto_parallel
// Multiplicative tolerances for approximate floating-point comparisons.
static const double kFloatDeviationMinus = 0.9999999;
static const double kFloatDeviationPlus = 1.0000001;
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/binary_set.h"
namespace oneflow {
namespace auto_parallel {
namespace {
// A static function for initialization of log_2 mapping
std::unordered_map<BinarySetEntryType, int32_t> InitLog2() {
std::unordered_map<BinarySetEntryType, int32_t> log_2;
  for (int32_t i = 0; i < int32_t(8 * sizeof(BinarySetEntryType)); i++) {
    // Shift a value of the entry type itself; shifting the int literal 1
    // would overflow once i reaches 31.
    log_2[BinarySetEntryType(1) << i] = i;
}
return log_2;
}
// Initialization of log_2 mapping
// Takes log2 of an integer value: 2^n -> n.
const std::unordered_map<BinarySetEntryType, int32_t> log_2 = InitLog2();
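// For example, log_2.at(1) == 0 and log_2.at(8) == 3; QuickOutput() below uses this
// table to map an isolated set bit back to its index.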
} // namespace
// Constructor
BinarySet::BinarySet(int32_t size_of_set) : size_of_set_(size_of_set) {
int32_t k = (size_of_set - 1) / bit_entry_type_ + 1;
binary_set_values_.resize(k, 0);
}
// Initialization if needed
void BinarySet::Initialize(int32_t size_of_set) {
size_of_set_ = size_of_set;
int32_t k = (size_of_set - 1) / bit_entry_type_ + 1;
binary_set_values_.resize(k, 0);
}
// Clear all the elements in the set
void BinarySet::Clear() { binary_set_values_.assign(binary_set_values_.size(), 0); }
// Check whether the i-th element is in this subset
bool BinarySet::CheckExistence(int32_t i) const {
int32_t k = i / bit_entry_type_;
int32_t j = i % bit_entry_type_;
return bool((binary_set_values_[k] >> j) & 1);
}
// Add i-th element into this subset
void BinarySet::AddEntry(int32_t i) {
int32_t k = i / bit_entry_type_;
int32_t j = i % bit_entry_type_;
  binary_set_values_[k] |= (BinarySetEntryType(1) << j);
}
// Take i-th element out from this subset
void BinarySet::DeleteEntry(int32_t i) {
int32_t k = i / bit_entry_type_;
int32_t j = i % bit_entry_type_;
  binary_set_values_[k] &= ~(BinarySetEntryType(1) << j);
}
// Get the union with another subset and store it into u
void BinarySet::UnionTo(const BinarySet& bs, BinarySet& u) {
for (int32_t k = 0; k < binary_set_values_.size(); k++) {
u.binary_set_values_[k] = binary_set_values_[k] | bs.binary_set_values_[k];
}
}
// If this binary set intersects another one
bool BinarySet::IfIntersect(const BinarySet& bs) const {
int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size());
for (int32_t k = 0; k < min_bs_size; k++) {
if (binary_set_values_[k] & bs.binary_set_values_[k]) { return true; }
}
return false;
}
// Get the intersection with another subset and store it into i
void BinarySet::IntersectionTo(const BinarySet& bs, BinarySet& i) const {
int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size());
if (min_bs_size > i.binary_set_values_.size()) { i.binary_set_values_.resize(min_bs_size, 0); }
for (int32_t k = 0; k < binary_set_values_.size(); k++) {
i.binary_set_values_[k] = binary_set_values_[k] & bs.binary_set_values_[k];
}
}
// Count number of elements in this subset
int32_t BinarySet::Total() const {
int32_t t = 0;
for (int32_t k = 0; k < binary_set_values_.size(); k++) {
BinarySetEntryType bsv = binary_set_values_[k];
bsv = (bsv & 0x5555555555555555) + ((bsv >> 1) & 0x5555555555555555);
bsv = (bsv & 0x3333333333333333) + ((bsv >> 2) & 0x3333333333333333);
bsv = (bsv & 0x0F0F0F0F0F0F0F0F) + ((bsv >> 4) & 0x0F0F0F0F0F0F0F0F);
bsv = (bsv & 0x00FF00FF00FF00FF) + ((bsv >> 8) & 0x00FF00FF00FF00FF);
bsv = (bsv & 0x0000FFFF0000FFFF) + ((bsv >> 16) & 0x0000FFFF0000FFFF);
// bsv = (bsv & 0x00000000FFFFFFFF) + ((bsv >> 32) & 0x00000000FFFFFFFF);
t += int32_t(bsv);
}
return t;
}
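// The masked additions above are the classic SWAR population count: each step sums
// adjacent 1-, 2-, 4-, ... bit fields in parallel. An equivalent but less manual
// sketch (assuming <bitset> were included) would be, per entry:
//   t += std::bitset<8 * sizeof(BinarySetEntryType)>(bsv).count();
// Compiler builtins such as __builtin_popcount compute the same value.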
// Output all the elements in the subset
void BinarySet::Output(std::vector<int32_t>& out) const {
out.clear();
for (int32_t i = 0; i < size_of_set_; i++) {
if (CheckExistence(i)) { out.emplace_back(i); }
}
}
// Output all the elements in the subset
void BinarySet::QuickOutput(std::vector<int32_t>& out) const {
out.clear();
for (int32_t i = 0; i < binary_set_values_.size(); i++) {
BinarySetEntryType x = binary_set_values_[i];
BinarySetEntryType y = 0;
while (x) {
y = x;
x &= x - 1;
out.emplace_back(i * BinarySet::bit_entry_type_ + log_2.find(y - x)->second);
}
}
}
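// How the loop works: x &= x - 1 clears the lowest set bit of x, so y - x isolates
// exactly that bit and log_2 maps it back to its index. Worked example with a
// hypothetical single entry x = 0b10100:
//   iteration 1: y = 0b10100, x becomes 0b10000, y - x = 0b00100 -> index 2
//   iteration 2: y = 0b10000, x becomes 0b00000, y - x = 0b10000 -> index 4
// so out receives {2, 4} in ascending order.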
// Add elements of input into this subset
void BinarySet::AddEntries(std::vector<int32_t>& in) {
for (int32_t i : in) { AddEntry(i); }
}
// If two binary sets are equal to each other
bool BinarySet::operator==(const BinarySet& rhs) const {
if (size_of_set_ != rhs.size_of_set_) { return false; }
for (int32_t i = 0; i < binary_set_values_.size(); i++) {
if (binary_set_values_[i] != rhs.binary_set_values_[i]) { return false; }
}
return true;
}
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
#include <cstdlib>
#include <unordered_map>
#include <vector>
#include "oneflow/core/common/hash.h"
namespace oneflow {
namespace auto_parallel {
// Only unsigned int (32-bit) entries have been validated here; the log_2 table and
// the bit tricks in binary_set.cpp were written against 32-bit entries, so widening
// this type should be re-checked before use.
using BinarySetEntryType = unsigned int;
class BinarySet {
public:
BinarySet() {}
explicit BinarySet(int32_t size_of_set);
// Initialization
void Initialize(int32_t size_of_set);
// Clear all the elements in the set
void Clear();
// Check whether the i-th element is in this subset
bool CheckExistence(int32_t i) const;
// Add i-th element into this subset
void AddEntry(int32_t i);
// Take i-th element out from this subset
void DeleteEntry(int32_t i);
// Get the union with another subset and store it into u
void UnionTo(const BinarySet& bs, BinarySet& u);
// If this binary set intersects another one
bool IfIntersect(const BinarySet& bs) const;
// Get the intersection with another subset and store it into i
void IntersectionTo(const BinarySet& bs, BinarySet& i) const;
// Count number of elements in this subset
int32_t Total() const;
// Output all the elements in the subset
void Output(std::vector<int32_t>& out) const;
// Output all the elements in the subset
void QuickOutput(std::vector<int32_t>& out) const;
// Add elements of input into this subset
void AddEntries(std::vector<int32_t>& in);
// If two binary sets are equal to each other
bool operator==(const BinarySet& rhs) const;
inline int32_t GetSizeOfSet() const { return size_of_set_; };
private:
friend struct BinarySetHasher;
// binary_set_values_ contains a vector of 64-bit or 32-bit int.
// Each bit means whether an entry is in the set
std::vector<BinarySetEntryType> binary_set_values_;
int32_t size_of_set_ = -1;
// total bits of the entry type in vector binary_set_values_.
static constexpr int32_t bit_entry_type_ = 8 * sizeof(BinarySetEntryType);
};
struct BinarySetHasher {
std::size_t operator()(const BinarySet& bs) const {
using std::hash;
using std::size_t;
size_t h = 0;
for (int i = 0; i < bs.binary_set_values_.size(); i++) {
h = HashCombine(h, hash<BinarySetEntryType>()(bs.binary_set_values_[i]));
}
return h;
};
};
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_
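A minimal usage sketch of the BinarySet API declared above; illustrative only, assuming binary_set.{h,cpp} are compiled and on the include path (the file name binary_set_demo.cpp is made up):

// binary_set_demo.cpp
#include <cassert>
#include <vector>
#include "oneflow/core/auto_parallel/binary_set.h"

int main() {
  using oneflow::auto_parallel::BinarySet;
  // A universe of 40 elements occupies two 32-bit entries.
  BinarySet bs(40);
  bs.AddEntry(3);
  bs.AddEntry(35);
  assert(bs.CheckExistence(35));
  assert(bs.Total() == 2);

  std::vector<int32_t> members;
  bs.QuickOutput(members);  // members == {3, 35}
  bs.DeleteEntry(3);
  assert(bs.Total() == 1);
  return 0;
}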
@@ -16,8 +16,10 @@ limitations under the License.
#include <memory>
#include <string>

+#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/global_for.h"
@@ -34,9 +36,9 @@ limitations under the License.
namespace oneflow {

namespace {
-void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int32_t depth,
-                 int32_t max_depth, NdSbp& nd_sbp, std::vector<NdSbp>& nd_sbp_lists,
-                 std::unordered_map<::oneflow::NdSbp, int32_t>& nd_sbp_universe) {
+void DfsSetNdSbp(const std::vector<SbpParallel>& id2sbp_parallel, int32_t depth, int32_t max_depth,
+                 NdSbp& nd_sbp, std::vector<NdSbp>& nd_sbp_lists,
+                 std::unordered_map<NdSbp, int32_t>& nd_sbp_universe) {
  if (depth == max_depth) {
    nd_sbp_universe[nd_sbp] = nd_sbp_lists.size();
    nd_sbp_lists.push_back(nd_sbp);
@@ -49,7 +51,7 @@ void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int
}

// Let a nd sbp be consistent with the given hierarchy number
-Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
+Maybe<NdSbp> SetNdSbpDim(const NdSbp& nd_sbp, int32_t hierarchy_num) {
  // Do not need to change
  if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; }
  // (S0, S0) -> S0
@@ -71,9 +73,63 @@ Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
  return new_sbp;
}
int32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {
int32_t total_num_split = 1;
for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); i++) {
if (nd_sbp.sbp_parallel(i).has_split_parallel()) {
total_num_split *= parallel_desc.hierarchy()->At(i);
}
}
return total_num_split;
}
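// For example (hypothetical values): with hierarchy [4, 2] and nd_sbp (S(0), B),
// only axis 0 is split, so TotalNumSplit returns 4; with (S(0), S(1)) it returns 8.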
// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
Maybe<void> AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) {
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()
&& sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {
return Maybe<void>::Ok();
}
if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) {
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
// We do not support [3]: P <--> [2, 2]: (P, P) as well.
int32_t hierarchy_size = 0;
if (producer_parallel_desc.hierarchy()->elem_cnt()
< consumer_parallel_desc.hierarchy()->elem_cnt()) {
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*diag_node_pos = 1;
hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();
} else {
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*diag_node_pos = 0;
hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();
}
NdSbp broadcast_nd;
for (int32_t i = 0; i < hierarchy_size; i++) {
broadcast_nd.add_sbp_parallel();
broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
}
middle_sbps.emplace_back(broadcast_nd);
}
}
return Maybe<void>::Ok();
}
}  // namespace

-// A constructor with init, designed for uncustomized boxing collector
+// A constructor with init, designed for pre-stored boxing collector
BoxingCollector::BoxingCollector(int32_t max_axis) { CHECK_JUST(Init(max_axis)); }

// Construct a boxing collector with given maximum number of axis
@@ -92,6 +148,8 @@ Maybe<void> BoxingCollector::Init(int32_t max_axis) {
  JUST(GenerateCombination4SamePlacement(3));
  JUST(GenerateCombination4DiffHierarchy(this, this));
  JUST(GenerateCombination4DiffPlacement(this, this));
+  init_type_ = int32_t(enable_general_basic_communication
+                       || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());
  return Maybe<void>::Ok();
}
@@ -106,6 +164,8 @@ Maybe<void> BoxingCollector::Init(const BlobDesc& logical_blob_desc,
  // Get copy cost in lazy mode
  LazyMode::Guard enable_lazy_mode(true);
  JUST(GenerateCombination4SamePlacement(5, logical_blob_desc, parallel_desc));
+  init_type_ = int32_t(enable_general_basic_communication
+                       || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());
  return Maybe<void>::Ok();
}
@@ -173,6 +233,7 @@ void BoxingCollector::GenerateMap1d2nd() {
  // Generate the id Map from 1d sbp to nd sbp
  NdSbp nd_sbp;
  for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { nd_sbp.add_sbp_parallel(); }
+  id_1d_2_nd_.clear();
  id_1d_2_nd_.resize(m, -1);
  for (int32_t id_1d = 0; id_1d < m; id_1d++) {
    for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) {
@@ -190,10 +251,13 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
  // NOTE: The performance of this function is the same for different hierarchies
  int32_t world_size = GlobalProcessCtx::WorldSize();
  Shape hierarchy44({4 * world_size, 4 * world_size});
+  int32_t virtual_range_size = hierarchy44.elem_cnt();
  std::shared_ptr<Shape> virtual_hierarchy = std::make_shared<Shape>(hierarchy44);
  auto parallel_desc = JUST(ParallelDesc::New(
      "cpu", {"0:0-" + std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy));
-  BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
+  BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,
+                      virtual_range_size, virtual_range_size, virtual_range_size},
+                     DataType::kInt8, /*is_dynamic=*/false);
  JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc));
  return Maybe<void>::Ok();
}
@@ -204,7 +268,9 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
                                                               const ParallelDesc& parallel_desc) {
  // Store the origin transfer cost information
  int32_t n = nd_sbp_lists_.size();
+  minimum_copy_cost_.clear();
  minimum_copy_cost_.resize(n);
+  middle_nodes_.clear();
  middle_nodes_.resize(n);
  for (int32_t i = 0; i < n; i++) {
    minimum_copy_cost_[i].resize(n);
@@ -250,7 +316,7 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
          minimum_copy_cost_[i][j] = curr_copy_cost;
        }
      }
-      // If the minimum copy cost remians infinity, adding one middle node does not make it.
+      // If the minimum copy cost remains infinity, adding one middle node does not make it.
      if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { continue; }
      // Find those middle nodes
      for (int32_t k = 0; k < n; k++) {
@@ -291,6 +357,7 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffHierarchy(
  // Search the path that contains one of the diagonal sbp
  int32_t n = nd_sbp_lists_.size();
+  diag_node_diff_hierarchy_.clear();
  diag_node_diff_hierarchy_.resize(n);
  for (int32_t i = 0; i < n; i++) {
    diag_node_diff_hierarchy_[i].resize(n);
@@ -309,7 +376,10 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
    BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {
  // Virtual parallel and blob description
  int32_t world_size = GlobalProcessCtx::WorldSize();
-  BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
+  int32_t virtual_range_size = 4 * world_size * (4 * world_size + 1);
+  BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,
+                      virtual_range_size, virtual_range_size, virtual_range_size},
+                     DataType::kInt8, /*is_dynamic=*/false);
  // Virtual placements before transfer
  Shape in_hierarchy44({4 * world_size + 1, 4 * world_size});
  std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(in_hierarchy44);
@@ -334,6 +404,7 @@ Maybe<void> BoxingCollector::ComputeCostFor1DSbpDiffPlacement(
  // Number of 1d sbp
  int32_t m = id2sbp_parallel_.size();
  // Compute the cost while transferring a 1D sbp between different placements
+  cost_4_diff_placement.clear();
  cost_4_diff_placement.resize(m);
  for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) {
    cost_4_diff_placement[id_1d_producer].resize(m, GetMaxVal<float>());
@@ -364,6 +435,7 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
  // Search the path that contains two of the diagonal sbp
  int32_t n = nd_sbp_lists_.size();
+  diag_node_diff_placement_.clear();
  diag_node_diff_placement_.resize(n);
  for (int32_t i = 0; i < n; i++) {
    diag_node_diff_placement_[i].resize(n);
...@@ -496,64 +568,74 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const ...@@ -496,64 +568,74 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) { if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) {
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
// If compute_cost==false + 2D sbp + same placment + nccl logical + not (p->b), if (producer_parallel_desc == consumer_parallel_desc && sbp_producer == sbp_consumer) {
// Use nccl logical send recv instead of middle node.
// Note that in op sbp inference, cost of middle nodes is still used for the moment.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
if (compute_cost == false && producer_parallel_desc.hierarchy()->NumAxes() == 2
&& producer_parallel_desc == consumer_parallel_desc
&& !(NdSbpHasPartialParallel(sbp_consumer)) &&
      // TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open
      // this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P)
      // -> (P, S0) -> (B, S0); neither same dim 0 nor send recv in nccl logical pass can deal with
      // (P, P) -> (P, S0) at the moment.
      // !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) &&
      Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
    VLOG(3) << "Middle node insertion is skipped when src sbp is " << NdSbpToString(sbp_producer)
            << " dst sbp is " << NdSbpToString(sbp_consumer)
            << ", because nccl logical send/recv can handle this.";
    return Maybe<void>::Ok();
  }
#endif  // WITH_CUDA
  // Dealing with 1D sbp to 1D sbp
  if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) {
    JUST(AskSbpCombinationFor1DSbp(sbp_producer, sbp_consumer, producer_parallel_desc,
                                   consumer_parallel_desc, middle_sbps, diag_node_pos));
    // No middle nodes for the other 1d-sbp combinations
    return Maybe<void>::Ok();
  }
#ifdef WITH_CUDA
  // Use a general basic communication if no P in the consumer
  if (((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
        && producer_parallel_desc == consumer_parallel_desc)
       || enable_general_basic_communication)
      && (!NdSbpHasPartialParallel(sbp_consumer))
      && producer_parallel_desc.device_type() == DeviceType::kCUDA
      && consumer_parallel_desc.device_type() == DeviceType::kCUDA) {
    if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) {
      // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
      // Directly applying general basic communication would have O(n^2) time complexity for P->B
      // Using two-step transfer would reduce it to a linear cost
      JUST(AskSbpCombination4GeneralBasicCommunication(
          sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
          consumer_parallel_desc, middle_sbps, diag_node_pos));
    }
    // Otherwise, one-step transfer
    return Maybe<void>::Ok();
  }
#endif  // WITH_CUDA
#ifdef WITH_ROCM
  // Use a general basic communication if no P in the consumer
  if (((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
        && producer_parallel_desc == consumer_parallel_desc)
       || enable_general_basic_communication)
      && (!NdSbpHasPartialParallel(sbp_consumer))
      && producer_parallel_desc.device_type() == DeviceType::kCUDA
      && consumer_parallel_desc.device_type() == DeviceType::kCUDA) {
    if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) {
      // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
      // Directly applying general basic communication would have O(n^2) time complexity for P->B
      // Using two-step transfer would reduce it to a linear cost
      JUST(AskSbpCombination4GeneralBasicCommunication(
          sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
          consumer_parallel_desc, middle_sbps, diag_node_pos));
    }
    // Otherwise, one-step transfer
    return Maybe<void>::Ok();
  }
#endif  // WITH_ROCM
  if (JUST(ComputeLazyCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc,
                                           producer_parallel_desc, consumer_parallel_desc,
                                           /*requires_same_sbp=*/false))
      < GetValidMaxCopyCost()) {
    return Maybe<void>::Ok();
  } else {
    int32_t require_init_type =
        int32_t(enable_general_basic_communication
                || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());
    if (init_type_ != require_init_type) {
      // We assemble the boxing table from S(0) to S(5).
      // Those splitting in higher axes are considered in the customized boxing.
      constexpr int32_t kRegularMaxSplitAxes = 6;
      JUST(Init(kRegularMaxSplitAxes));
    }
  }
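  // To make the complexity comment above concrete (an illustrative count, not from the original
  // source): with n = 4 devices turning a partial-sum (P) tensor into a broadcast (B) one, a
  // direct one-step general basic communication ships every partial copy to every other device,
  // i.e. n * (n - 1) = 12 full-tensor transfers. The two-step route P -> S -> B instead behaves
  // like a reduce-scatter followed by an all-gather, so each device moves roughly two tensor
  // sizes of data no matter how large n grows.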
@@ -568,6 +650,7 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
  // Transfer for the same machines, devices and hierarchy.
  if (sbp_producer == sbp_consumer) { return Maybe<void>::Ok(); }
  const auto& parallel_hierarchy = producer_parallel_desc.hierarchy();
  *diag_node_pos = 0;
  // Dealing with nD sbp, n>2
  if (parallel_hierarchy->NumAxes() > 2) {
@@ -675,7 +758,7 @@ Maybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(
  if (same_placement) {
    // Different hierarchies
    CHECK_OR_RETURN(diag_node_diff_hierarchy_.size() > 0)
        << "Have not initialized the combination table for different hierarchies yet! "
           "Please run JUST(GenerateCombination4DiffHierarchy(this, this)); "
           "before asking sbp combination for different parallel description.";
    if (JUST(Ask1Combination4DiffPlacement(
@@ -687,7 +770,7 @@ Maybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(
  } else {
    // Different placements
    CHECK_OR_RETURN(diag_node_diff_placement_.size() > 0)
        << "Have not initialized the combination table for different placements yet! "
           "Please run JUST(GenerateCombination4DiffPlacement(this, this)); "
           "before asking sbp combination for different parallel description.";
    if (JUST(Ask1Combination4DiffPlacement(
@@ -787,9 +870,9 @@ Maybe<void> BoxingCollector::Generate1Combination4DiffHierarchy(
        min_path_length = path_length;
        // Find a candidate with small cost
        if (curr_cost < min_cost * kFloatDeviationPlus) {
          // Find a smaller cost, clear the previous path.
          if (curr_cost < min_cost * kFloatDeviationMinus) {
            min_cost = curr_cost;
            diag_nodes.clear();
          }
@@ -1007,4 +1090,105 @@ Maybe<void> BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logica
  return Maybe<void>::Ok();
}
// Ask for sbp combination for general basic communication
Maybe<void> BoxingCollector::AskSbpCombination4GeneralBasicCommunication(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
  // (P, X) -> (B, X) || (X, P) -> (X, B), X is any SBP
// One step transfer, at most 50% reduction in the transfer cost, do not use middle nodes
if (producer_parallel_desc == consumer_parallel_desc
&& producer_parallel_desc.hierarchy()->NumAxes() == 2
&& (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0)
|| sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) {
return Maybe<void>::Ok();
}
// Not enough gain in transfer cost, do not use middle nodes
int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc);
int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc);
if (2 * (partial_ratio4producer + broadcast_ratio4consumer)
>= partial_ratio4producer * broadcast_ratio4consumer) {
return Maybe<void>::Ok();
}
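  // A numerical reading of the gain condition above (illustrative): if the one-step transfer
  // scales roughly with partial_ratio * broadcast_ratio and the two-step transfer with
  // partial_ratio + broadcast_ratio, middle nodes only pay off once p * b > 2 * (p + b).
  // For p = b = 4 we get 16 >= 16 and keep the one-step transfer; for p = b = 5 we get
  // 25 > 20 and fall through to the two-step transfer below.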
bool close2producer = true;
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) {
// Get close to the one with more splits
close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc)
> TotalNumSplit(sbp_consumer, consumer_parallel_desc);
} else {
// Get close to the one with more machines
close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num();
}
// Get the contiguous sbp
if (close2producer) {
JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps));
*diag_node_pos = 1;
} else {
JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps));
*diag_node_pos = 0;
}
return Maybe<void>::Ok();
}
// Ask for an all-split sbp which is close to the original one
Maybe<void> BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp,
const ParallelDesc& parallel_desc,
const BlobDesc& logical_blob_desc,
std::vector<NdSbp>& middle_sbps) {
Shape remain_shape = logical_blob_desc.shape();
Shape rest_split_shape = logical_blob_desc.shape();
int32_t dim_shape = remain_shape.NumAxes();
// Initialize the remains and splitting
// logical_blob_desc.shape() == remain_shape .* rest_split_shape;
for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); }
for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
if (sbp.has_split_parallel()) {
int32_t axis = sbp.split_parallel().axis();
int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
remain_shape.Set(axis, remain_shape.At(axis) / split_num);
rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num);
}
}
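  // Worked example of the bookkeeping above (illustrative): for logical shape (12, 8) with
  // nd_sbp (S0, S1) on hierarchy [2, 2], this loop leaves remain_shape = (6, 4) and
  // rest_split_shape = (2, 2), so the element-wise product of the two still recovers (12, 8).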
// Get the contiguous sbp
NdSbp new_sbp = nd_sbp;
for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
if (sbp.has_split_parallel()) {
int32_t axis = sbp.split_parallel().axis();
// split shape is the total splitting number starting from sbp_id to the end
rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num);
} else {
// change P or B to S(axis)
int32_t axis = -1;
      // 4096 is large enough; we are unlikely to have that many devices
int32_t min_split_num = 4096;
// We need to pick a suitable axis
for (int32_t i = 0; i < remain_shape.NumAxes(); i++) {
if (remain_shape.At(i) % split_num == 0) {
if (rest_split_shape.At(i) < min_split_num) {
          // Pick the axis with the smallest splitting number among the rest of the sbp
min_split_num = rest_split_shape.At(i);
axis = i;
}
}
}
// P, B -> S(axis)
if (axis >= 0) {
new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis);
remain_shape.Set(axis, remain_shape.At(axis) / split_num);
} else {
// Can not find a suitable contiguous sbp
return Maybe<void>::Ok();
}
}
}
// Add the new sbp into the middle node lists
middle_sbps.emplace_back(new_sbp);
return Maybe<void>::Ok();
}
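// A short trace of AskCloseAllSplitSbp (illustrative): for nd_sbp = (P, S0) on hierarchy [2, 2]
// with logical shape (8, 8), the first pass records the existing split (remain_shape = (4, 8),
// rest_split_shape = (2, 1)); the second pass rewrites the P at sbp_id 0 into a split on axis 1,
// since rest_split_shape.At(1) = 1 is the smallest admissible count, so the middle node pushed
// into middle_sbps is (S1, S0).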
} // namespace oneflow
@@ -129,10 +129,19 @@ class BoxingCollector final {
                                     BoxingCollector* boxing_collector_producer,
                                     BoxingCollector* boxing_collector_consumer,
                                     const std::vector<std::vector<int32_t>>& diag_nodes);
// Ask for sbp combination for general basic communication
Maybe<void> AskSbpCombination4GeneralBasicCommunication(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos);
  // Ask for an all-split sbp which is close to the original one
Maybe<void> AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,
const BlobDesc& logical_blob_desc,
std::vector<NdSbp>& middle_sbps);
  // Stores all the possible SbpParallel.
  HashMap<SbpParallel, int32_t> sbp_parallel_universe_;
  // Relationship between id and Sbp Parallel
  std::vector<SbpParallel> id2sbp_parallel_;
  // minimum cost
  // minimum_copy_cost[producer][consumer]
  std::vector<std::vector<double>> minimum_copy_cost_;
@@ -142,18 +151,23 @@ class BoxingCollector final {
  // nodes that need to be inserted
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> middle_nodes_;
  // Stores all the possible NdSbp.
  std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;
  // Relationship between id and Nd Sbp
  std::vector<NdSbp> nd_sbp_lists_;
  // The diagonal middle node for different placements
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_placement_;
  // The diagonal middle node for different hierarchies in the same placement
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_hierarchy_;
  // Id Map from 1d sbp to 2d sbp
  // For example: B -> (B, B), S0 -> (S0, S0)
  std::vector<int32_t> id_1d_2_nd_;
  // The sbp size in the combination table
  int32_t hierarchy_num_;
  // How the boxing collector is initialized
  int32_t init_type_ = -1;
  // Enable general basic communication or not
  const bool enable_general_basic_communication =
      ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
}; // class BoxingCollector
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <string>
#include "oneflow/core/auto_parallel/sbp_collector.h"
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/auto_parallel/sbp_constructor.h"
namespace oneflow {
namespace auto_parallel {
namespace {
// Whether the given binary set intersects all the sbp sets of the consumers
bool IfIntersectAll(
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const BinarySet& bs) {
for (const auto& sbp_set_group : consumer_bn2sbp_set) {
if (!bs.IfIntersect(sbp_set_group.second)) { return false; }
}
return true;
}
// Find unique sbp sets
void FindUniqueSbpSets(
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,
BinarySet& unique_sbps) {
std::vector<int32_t> sbp_ids;
// count the number of sbp
for (const auto& sbp_set_group : consumer_bn2sbp_set) {
sbp_set_group.second.QuickOutput(sbp_ids);
for (int32_t sbp_id : sbp_ids) { accumulator[sbp_id]++; }
}
// find unique sbp and clear the accumulator
for (const auto& sbp_id : all_sbp_set) {
if (accumulator[sbp_id] == 1) { unique_sbps.AddEntry(sbp_id); }
accumulator[sbp_id] = 0;
}
}
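// Example (illustrative): with consumers A: {B, S0, S1} and C: {B, S0}, the accumulator counts
// B and S0 twice but S1 only once, so only S1 ends up in unique_sbps; the touched accumulator
// entries are reset to zero on the way out so the buffer can be reused for the next blob.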
// Find unique sbp groups
void FindUniqueSbpGroups(
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,
BinarySet& bs_buffer, std::vector<BinarySet>& unique_sbp_groups) {
// find the unique sbp sets
BinarySet unique_sbps(accumulator.size());
FindUniqueSbpSets(consumer_bn2sbp_set, all_sbp_set, accumulator, unique_sbps);
// A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}
// {S1, S2, S3} show up only once, a parallel candidate should not contain two of them
for (const auto& sbp_set_group : consumer_bn2sbp_set) {
unique_sbps.IntersectionTo(sbp_set_group.second, bs_buffer);
// Find those unique sbp groups with more than two sbp
// For example {B, S1, S2} is an impossible proxy candidate,
// since {S1, S2} is only contained by A but not contained by C and D.
    // A could be either S1 or S2. The tensor does not need to be transferred to both S1 and S2.
if (bs_buffer.Total() >= 2) { unique_sbp_groups.push_back(bs_buffer); }
}
bs_buffer.Clear();
}
// If not contains two sbp from a same unique group
bool No2SbpFromSameUniqueGroup(const BinarySet& bs,
const std::vector<BinarySet>& unique_sbp_groups) {
BinarySet intersection(bs.GetSizeOfSet());
for (const auto& unique_sbp_group : unique_sbp_groups) {
bs.IntersectionTo(unique_sbp_group, intersection);
// For example {B, S1, S2} is an impossible proxy candidate,
// since {S1, S2} is only contained by A but not contained by C and D.
    // A could be either S1 or S2. The tensor does not need to be transferred to both S1 and S2.
if (intersection.Total() >= 2) { return false; }
}
return true;
}
} // namespace
// Default constructor for SbpCollector
// Don't allow any special case for broadcast!
SbpCollector::SbpCollector() {
// initialize Sbp Parallel Universe with broadcast.
// NdSbp sbp_broadcast;
// sbp_broadcast.mutable_broadcast_parallel();
// nd_sbp_universe_[sbp_broadcast] = 0;
// id2nd_sbp_.push_back(sbp_broadcast);
}
// Collect all the possible Sbp Parallel from a NdSbpSignature
void SbpCollector::CollectUniverse(const NdSbpSignature& nd_sbp_sig) {
for (auto& bn_sbp_pair : nd_sbp_sig.bn_in_op2nd_sbp()) {
if (nd_sbp_universe_.find(bn_sbp_pair.second) == nd_sbp_universe_.end()) {
int32_t curr_size = nd_sbp_universe_.size();
nd_sbp_universe_[bn_sbp_pair.second] = curr_size;
id2nd_sbp_.push_back(bn_sbp_pair.second);
}
}
}
// Collect all the possible Sbp Parallel from a SbpNode
void SbpCollector::CollectUniverse(const SbpNode* sbp_node) {
for (auto& nd_sbp_sig : sbp_node->sbp_sig_list_) { CollectUniverse(nd_sbp_sig); }
}
// Collect all the possible Sbp Parallel from a SbpGraph
void SbpCollector::CollectUniverse(const SbpGraph& sbp_graph) {
for (auto* sbp_node : sbp_graph.node_list_) { CollectUniverse(sbp_node); }
accumulator_.resize(nd_sbp_universe_.size(), 0);
bs_buffer_.Initialize(nd_sbp_universe_.size());
}
// TODO: Auto Placement!
// It only collects the same sbp with the same parallel description.
// At the moment their hierarchies are the same!
// Initialize copy cost from producer to proxy of producer
void SbpCollector::InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy,
const LogicalBlobId& lbi) const {
// the only edge from producer to proxy of producer
SbpEdge* sbp_edge = sbp_proxy->edges_in_[0];
SbpNode* sbp_node_producer = sbp_edge->start_node_;
sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());
int32_t consumer_sbp_size = sbp_proxy->parallel_candidates_.size();
// look through sbp signature in producer
for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();
sbp_id_producer++) {
sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
}
// Assemble copy cost from producer to proxy of producer
OpNode* producer = sbp_node_producer->op_node_;
// get parallel description. Number of devices.
const ParallelDesc& producer_parallel_desc = producer->parallel_desc();
  // Need to be careful: the logical blob description should be independent of the current
  // NdSbp. Use producer or op_node?
const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);
const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));
// A buffer to store the sbp parallel id
std::vector<int32_t> sbp_parallel_ids;
// look through sbp signature in producer
for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();
sbp_id_producer++) {
// get sbp parallel for a logical blob in producer
const auto& producer_sbp_bn_in_op2sbp_parallel =
sbp_node_producer->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();
const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);
// look through sbp parallel set in consumer
for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
const BinarySet& sbp_parallel_set = sbp_proxy->parallel_candidates_[sbp_id_consumer];
sbp_parallel_set.QuickOutput(sbp_parallel_ids);
// look through all sbp parallels in a sbp parallel set
for (int32_t sbp_parallel_id : sbp_parallel_ids) {
// get sbp parallel for a logical blob in consumer
const NdSbp& sbp_consumer = id2nd_sbp_[sbp_parallel_id];
// compute copy cost for a specific logical blob
        // Use the parallel description of the producer as that of the consumer for now.
sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] +=
CHECK_JUST(ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc,
producer_parallel_desc,
producer_parallel_desc, /*is_same=*/false));
}
}
}
}
// Initialize copy cost from proxy of producer to consumers
void SbpCollector::InitializeCopyCostFromProxy2Consumer(
SbpNode* sbp_proxy,
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const HashMap<std::string, SbpNode*>& op_name2sbp_node) const {
// Connect sbp proxy and consumers
for (const auto& consumer_bn_group : consumer_bn2sbp_set) {
// consumer in cost model
SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;
// input blob name of logical blob in consumer
const std::string& ibn = consumer_bn_group.first.second;
// check is_mutable in consumer
OpNode* consumer = sbp_node_consumer->op_node_;
CHECK(!RequireSameSbp(consumer, ibn)) << "Create a proxy for an unsuitable consumer!\n";
// Connect sbp proxy and consumer
sbp_proxy->PointTo(sbp_node_consumer);
// the sbp edge connecting proxy and consumer
SbpEdge* sbp_edge = sbp_node_consumer->FindEdgeWithNode(sbp_proxy);
sbp_edge->cost_.resize(sbp_proxy->parallel_candidates_.size());
int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();
// look through sbp parallel set in proxy
for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_proxy->parallel_candidates_.size();
sbp_id_producer++) {
// initialization for copy cost
sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
// get sbp parallel set for a logical blob in proxy
BinarySet& parallel_candidate = sbp_proxy->parallel_candidates_[sbp_id_producer];
// look through sbp signatures in consumers
for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
// get sbp parallel for a logical blob in consumer
const auto& consumer_sbp_bn_in_op2sbp_parallel =
sbp_node_consumer->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();
const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);
if ((!parallel_candidate.CheckExistence(nd_sbp_universe_.find(sbp_consumer)->second))) {
sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] = GetMaxVal<float>();
}
}
}
}
}
// Export list of possible combination of Sbp Parallels
void SbpCollector::ProxySbpCandidate(const OpGraph& op_graph,
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
SbpGraph& sbp_graph) {
// If needed, we can output the mapping from operator name to its proxy.
// HashMap<std::string, HashMap<LogicalBlobId, SbpNode*>>&
// op_name2lbi2sbp_proxy;
// mapping from a logical blob id to index
HashMap<LogicalBlobId, int32_t> lbi2index;
// mapping from the index to producer, consumer and corresponding input blob name, possible sbp
// sets
std::vector<const OpNode*> index2producer;
std::vector<std::unordered_set<int32_t>> index2sbp_set;
// mapping from consumers and input blob names to an unordered_set of SBP Parallel.
std::vector<HashMap<std::pair<std::string, std::string>, BinarySet>> index2consumer_bn2sbp_set;
for (auto* consumer_sbp_node : sbp_graph.node_list_) {
auto* node = consumer_sbp_node->op_node_;
OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();
    // If boxing is not supported, just skip it.
if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) { return; }
for (const std::string& ibn : node->op().input_bns()) {
// Skip those blobs who enforce same SBP.
if (RequireSameSbp(node, ibn)) {
// Enforcing same SBP. Can not collect sbp from this blob.
continue;
}
const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);
const OpNode& producer = node->ProducerOpNode4Lbi(lbi);
// not building proxy for fixed operators
if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { return; }
// decide the index of a logical blob description
const auto& iterator_lbi = lbi2index.find(lbi);
int32_t index = 0;
if (iterator_lbi == lbi2index.end()) {
index = lbi2index.size();
lbi2index[lbi] = index;
// map from lbi to the producer
index2producer.push_back(&producer);
// Initialize consumer_bns and the sbp sets
index2consumer_bn2sbp_set.resize(index + 1);
index2sbp_set.resize(index + 1);
} else {
index = iterator_lbi->second;
}
// a set to store the id of all possible SBP Parallel for a downstream op
// should filter out repeated SBP Parallel by pre-storing them into an unordered_set
BinarySet& nd_sbp_ids = index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}];
nd_sbp_ids.Initialize(nd_sbp_universe_.size());
// The union sbp set of all the consumers
std::unordered_set<int32_t>& union_nd_sbp_ids = index2sbp_set[index];
for (auto& sbp_sig : consumer_sbp_node->sbp_sig_list_) {
const auto& map = sbp_sig.bn_in_op2nd_sbp();
const auto& iter = map.find(ibn);
CHECK(iter != map.end()) << "blob_name " << ibn << " not found in sbp signature";
const NdSbp& consumer_sbp = iter->second;
// filter out repeated SBP
int32_t sbp_universe_id = nd_sbp_universe_.find(consumer_sbp)->second;
nd_sbp_ids.AddEntry(sbp_universe_id);
union_nd_sbp_ids.insert(sbp_universe_id);
}
}
};
// A set of binary set with broadcast only
// std::unordered_set<BinarySet, BinarySetHasher> parallel_candidates_initializer;
// BinarySet one_broadcast(nd_sbp_universe_.size());
// one_broadcast.AddEntry(0);
// parallel_candidates_initializer.insert(std::move(one_broadcast));
// Decide if we should insert a proxy for each logical blob
for (auto& lbi_index : lbi2index) {
int32_t index = lbi_index.second;
// Only insert proxy for those blobs with multiple downstream consumers.
if (index2consumer_bn2sbp_set[index].size() < 2) { continue; }
// Maximum number of possible sbp in the proxy
int32_t max_num_sbp_proxy =
std::min(max_num_sbp_proxy_, index2consumer_bn2sbp_set[index].size());
// producer in cost model
const std::string& producer_name = index2producer[index]->op().op_name();
SbpNode* sbp_node_producer = op_name2sbp_node.find(producer_name)->second;
const LogicalBlobId& lbi = lbi_index.first;
// store all the binary sets of SBP Parallel into an unordered_set.
// std::vector<BinarySet> parallel_candidates;
// generate sbp proxy
SbpNode* sbp_proxy = sbp_graph.GenerateNode();
// A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}
// {S1, S2, S3} show up only once, a parallel candidate should not contain two of them
std::vector<BinarySet> unique_sbp_groups;
FindUniqueSbpGroups(index2consumer_bn2sbp_set[index], index2sbp_set[index], accumulator_,
bs_buffer_, unique_sbp_groups);
// Depth first search to collect Sbp Parallel information for the whole sbp set
DfsSbpSet(0, max_num_sbp_proxy, index2sbp_set[index], index2sbp_set[index].begin(),
index2consumer_bn2sbp_set[index], unique_sbp_groups, sbp_proxy->parallel_candidates_);
// Initialize computation cost
sbp_proxy->cost_.resize(sbp_proxy->parallel_candidates_.size(), 0);
// Transfer a logical blob from producer to a sbp proxy of this blob
sbp_node_producer->PointTo(sbp_proxy);
// Compute copy cost between producer and proxy
InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi);
// Build connection and compute copy cost between proxy and consumers
InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bn2sbp_set[index],
op_name2sbp_node);
// Unloading
for (const auto& consumer_bn_group : index2consumer_bn2sbp_set[index]) {
// consumer in cost model
SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;
// the sbp edge connecting producer and consumer
SbpEdge* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);
// unload logical blob from sbp edges
edge_found->UnloadLbi(lbi);
// Do not clip this edge. Save it for wait time.
// clip this edge if it no longer carries any blob
// We don't clip edges before since we have transfer cost
// Now we clip edges, which makes the topology simpler
if (edge_found->EmptyLbi() && edge_found->wait_time_ <= 0.0
&& edge_found->wait_time_ > -0.5) {
sbp_graph.ClipEdge(edge_found);
}
}
}
}
// Depth first search to collect Sbp Parallel information for different logical blob ids
void SbpCollector::DfsSbpSet(
int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,
const std::unordered_set<int32_t>::iterator& start_it,
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const std::vector<BinarySet>& unique_sbp_groups, std::vector<BinarySet>& parallel_candidates) {
if (depth > 0) {
if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer_)
&& No2SbpFromSameUniqueGroup(bs_buffer_, unique_sbp_groups)) {
// store the binary set into an unordered_set
parallel_candidates.push_back(bs_buffer_);
}
}
if (depth >= max_depth) { return; }
// go through the rest of the sbp parallel
std::unordered_set<int32_t>::iterator curr_it = start_it;
while (curr_it != sbp_sets.end()) {
// Take the value out
int32_t nd_sbp_num = *curr_it;
// Then move to the next pointer
++curr_it;
if (accumulator_[nd_sbp_num] == 0) {
bs_buffer_.AddEntry(nd_sbp_num);
++accumulator_[nd_sbp_num];
DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, unique_sbp_groups,
parallel_candidates);
bs_buffer_.DeleteEntry(nd_sbp_num);
--accumulator_[nd_sbp_num];
}
}
}
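// A small trace of the DFS above (illustrative): with sbp_sets = {B, S0, S1} and max_depth = 2,
// bs_buffer_ visits every non-empty subset of size <= 2 in the container's iteration order,
// e.g. {B}, {B, S0}, {B, S1}, {S0}, {S0, S1}, {S1}; a subset is kept as a parallel candidate
// only if it intersects every consumer's sbp set and contains at most one sbp from each unique
// group.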
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SBP_COLLECTOR_
#define SBP_COLLECTOR_
#include <unordered_map>
#include <vector>
#include <unordered_set>
#include <utility>
#include <type_traits>
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/job/local_sig_infer_hint.h"
#include "oneflow/core/job/job_builder.h"
// #include "sbp_constructor.h"
#define DEBUG_COLLECTOR_
namespace oneflow {
namespace auto_parallel {
class SbpCollector {
public:
SbpCollector();
~SbpCollector() {}
// Collect all the possible Sbp Parallel from a SbpGraph
void CollectUniverse(const SbpGraph& sbp_graph);
// Export list of possible combination of Sbp Parallels
void ProxySbpCandidate(const OpGraph& op_graph,
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
SbpGraph& sbp_graph);
private:
// Stores all the possible NdSbp.
std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;
// Relationship between id and Sbp Parallel
std::vector<NdSbp> id2nd_sbp_;
// Calculate number of downstream sbp
std::vector<int32_t> accumulator_;
// A binary set buffer to indicate sets of downstream sbp
BinarySet bs_buffer_;
// Collect all the possible Sbp Parallel from a NdSbpSignature
void CollectUniverse(const NdSbpSignature& nd_sbp_sig);
// Collect all the possible Sbp Parallel from a SbpNode
void CollectUniverse(const SbpNode* sbp_node);
// Initialize copy cost from producer to proxy of producer
void InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy, const LogicalBlobId& lbi) const;
// Initialize copy cost from proxy of producer to consumers
void InitializeCopyCostFromProxy2Consumer(
SbpNode* sbp_proxy,
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const HashMap<std::string, SbpNode*>& op_name2sbp_node) const;
// Maximum number of possible sbp in the proxy
const unsigned long max_num_sbp_proxy_ = 3;
// Depth first search to collect Sbp Parallel information for the whole sbp set
void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,
const std::unordered_set<int32_t>::iterator& sbp_set_it,
const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,
const std::vector<BinarySet>& unique_sbp_groups,
std::vector<BinarySet>& parallel_candidates);
}; // class SbpCollector
} // namespace auto_parallel
} // namespace oneflow
#endif // SBP_COLLECTOR_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/auto_parallel/sbp_constructor.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/auto_parallel/sbp_collector.h"
namespace oneflow {
namespace auto_parallel {
Maybe<void> SbpConstructor::Init(const OpGraph& op_graph, Job* job /*Maybe not use*/) {
JUST(InitSbpGraph(op_graph, *job));
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::InitSbpGraph(const OpGraph& op_graph, const Job& job) {
// TODO: process local node
JUST(GenerateNodeAndEdge(op_graph, job));
JUST(FillSbpSignatureForOpNode(op_graph, job));
JUST(InitComputationCost(op_graph));
if (enable_trunk_algo_) { JUST(ApplyTrunkAlgo()); }
if (use_sbp_collector_) {
// Load logical blobs on all sbp edges.
LoadLbi2SbpEdge(op_graph);
// Use sbp collector to create sbp proxy for nodes with multiple downstream operators.
SbpCollector sbp_collector;
sbp_collector.CollectUniverse(sbp_graph_);
sbp_collector.ProxySbpCandidate(op_graph, op_name2sbp_node_, sbp_graph_);
}
JUST(InitCopyCost(op_graph));
// TODO: Set all the sbp signature id to be 0 for initialization.
// Could revert it back to
// sbp_graph_.RandomSbpSignature(use_sbp_collector_);
// after settling down the synchronization of sbp strategy.
sbp_graph_.SetDefaultSbpSig();
double ori_cost = sbp_graph_.ComputeCost();
LOG(INFO) << "Initial cost: " << ori_cost;
// If we do not prune those parallel cast ops, steal the initial strategy from user setting and
// semi-auto parallelism
if (!job.job_conf().enable_auto_parallel_ignore_user_sbp_config()) {
JUST(StealSbpSignatureFromOpNode(op_graph, job));
ori_cost = sbp_graph_.ComputeCost();
LOG(INFO) << "OpGraph cost: " << ori_cost;
}
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::FindBestSbpSignature() {
double ori_cost = sbp_graph_.ComputeCost();
LOG(INFO) << "Initial cost: " << ori_cost;
int elimination_num = sbp_graph_.NodeAndEdgeEliminations();
LOG(INFO) << "Elimination number: " << elimination_num;
if (ori_cost > GetValidMaxCopyCost()) {
JUST(sbp_graph_.Find1Strategy4Greedy());
ori_cost = sbp_graph_.ComputeCost();
LOG(INFO) << "Greedy cost: " << ori_cost;
}
sbp_graph_.GreedyStrategy(4);
sbp_graph_.FinalizeSbp();
double final_cost = sbp_graph_.ComputeCost();
LOG(INFO) << "Final cost: " << final_cost;
if (ori_cost + 1.0 < final_cost) { LOG(WARNING) << "ori_cost less than final_cost!!!"; }
// TODO: Restart searching with another original random strategy
CHECK_LT_OR_RETURN(final_cost, GetValidMaxCopyCost())
<< "Failed! Auto parallel can't find a strategy with reasonable cost!";
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job) {
for (auto& op_conf : *job->mutable_net()->mutable_op()) {
const OpNode* node = op_graph.OpNode4OpName(op_conf.name());
SbpNode* sbp_node = op_name2sbp_node_[node->op().op_name()];
const NdSbpSignature& nd_sbp_sig = sbp_node->FinalSbpSignature();
// Update NdSbpSignature
(*job->mutable_job_parallel_view_conf()
->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()]
.CopyFrom(nd_sbp_sig);
// If we have 1D SbpSignature Conf
if (node->parallel_desc().hierarchy()->NumAxes() == 1) {
// Update SbpSignature
SbpSignature sbp_signature;
NdSbpSignatureToSbpSignature(nd_sbp_sig, &sbp_signature);
(*job->mutable_job_parallel_view_conf()
->mutable_op_name2sbp_signature_conf())[node->op().op_name()]
.CopyFrom(sbp_signature);
}
JUST(node->op().GetDumpNdSbpSignatureForOpConfFn()(nd_sbp_sig, &op_conf));
}
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job) {
JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf());
// Collect op_node
std::vector<OpNode*> op_node_list;
op_graph.ForEachNode([&](OpNode* op_node) {
// TODO: support local op
bool is_local_conf = false;
{
const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view();
const auto& iter = op_name2is_local.find(op_node->op().op_name());
if (iter != op_name2is_local.end()) { is_local_conf = iter->second; }
}
    CHECK(is_local_conf == false) << "Haven't dealt with local operators.";
op_node_list.push_back(op_node);
});
// Decide the order to visit the op
std::vector<int32_t> order;
auto CompareOpName = [&](OpNode* a, OpNode* b) {
return a->op().op_name().compare(b->op().op_name()) > 0;
};
auto_parallel::DecideOrder(op_node_list, order, CompareOpName);
std::vector<int32_t> output_order;
// Create sbp nodes
for (int32_t i = 0; i < op_node_list.size(); i++) {
OpNode* op_node = op_node_list[order[i]];
// Generate sbp node in cost model and link it with corresponding op node
SbpNode* sbp_node = sbp_graph_.GenerateNode();
// Mapping from sbp_node to op_node
sbp_node->op_node_ = op_node; // TODO: SetOpNode()
op_name2sbp_node_[op_node->op().op_name()] = sbp_node;
}
// Create sbp edges
for (int32_t i = 0; i < op_node_list.size(); i++) {
OpNode* op_node = op_node_list[order[i]];
// Get corresponding sbp node
SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
std::vector<OpNode*> output_node_list;
for (const auto* op_edge : op_node->out_edges()) {
output_node_list.push_back(op_edge->dst_node());
}
auto_parallel::DecideOrder(output_node_list, output_order, CompareOpName);
for (int32_t j : output_order) {
const auto& end_node_name = output_node_list[j]->op().op_name();
// Generate sbp edge in cost model
sbp_node->PointTo(op_name2sbp_node_[end_node_name]);
}
}
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job) {
// TODO: use user sbp signature in JobParallelViewConf
// const JobParallelViewConf& job_parallel_view_conf(job.job_parallel_view_conf());
JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
HashMap<std::string, const BlobDesc*> ibn2blob_desc;
auto FindShape4Blobs = [&](const PbRpf<std::string>& bns) -> Maybe<void> {
for (const std::string& ibn : bns) {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
const BlobDesc* logical_blob_desc = &op_node->LogicalBlobDesc4Lbi(lbi);
ibn2blob_desc.emplace(ibn, logical_blob_desc);
}
return Maybe<void>::Ok();
};
JUST(FindShape4Blobs(op_node->op().input_bns()));
JUST(FindShape4Blobs(op_node->op().output_bns()));
// Get logical blob description
auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {
const auto& it = ibn2blob_desc.find(ibn);
if (it == ibn2blob_desc.end()) {
return Error::InvalidValueError()
<< "Cannot find corresponding blob description for input_blob_name : " + ibn + " in "
+ op_node->op().op_name();
}
return *(it->second);
};
// Get all valid sbp_signatures
SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
JUST(op_node->op().GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, op_node->parallel_desc(),
&sbp_node->sbp_sig_list_, /*check_output=*/true));
sbp_node->InitializeSbp();
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job) {
// Steal some strategy from original op graph
for (auto* sbp_node : sbp_graph_.node_list_) {
// sbp_collectors do not have op_node
if (sbp_node->op_node_) {
for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {
if (*JUST(sbp_node->op_node_->op().nd_sbp_signature()) == sbp_node->sbp_sig_list_[sbp_id]) {
sbp_node->final_sbp_sig_id_ = sbp_id;
break;
}
}
}
}
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::InitComputationCost(const OpGraph& op_graph) {
// Compute computation cost for sbp nodes
JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
// get corresponding sbp node producer
SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];
// get parallel description. Number of devices.
const ParallelDesc& parallel_desc = op_node->parallel_desc();
CHECK_EQ_OR_RETURN(sbp_node->cost_.size(), sbp_node->sbp_sig_list_.size());
auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);
return op_node->LogicalBlobDesc4Lbi(lbi);
};
for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {
double comp_cost = JUST(op_node->op().GetComputeComplexity(
&sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc));
if (comp_cost > GetValidMaxCopyCost()) {
sbp_node->cost_[sbp_id] = comp_cost;
} else {
sbp_node->cost_[sbp_id] =
cost_ratio_ * comp_cost
* JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt();
}
}
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
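// Reading of the cost formula above (illustrative): a signature whose compute complexity already
// exceeds GetValidMaxCopyCost() keeps that raw value so it stays effectively forbidden; otherwise
// the complexity is scaled by cost_ratio_ and by the element count of the fastest input/output
// time shape, which roughly plays the role of how many times the op runs per iteration.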
Maybe<void> SbpConstructor::InitCopyCost(const OpGraph& op_graph) {
// Compute copy cost for sbp edges
op_graph.ForEachNode([&](OpNode* op_node) {
// get corresponding sbp node consumer
SbpNode* sbp_node_consumer = op_name2sbp_node_[op_node->op().op_name()];
// Initialize copy cost between two nodes
for (auto* sbp_edge : sbp_node_consumer->edges_in_) {
// producer sbp node
const auto* sbp_node_producer = sbp_edge->start_node_;
// skip it if proxy
if (!sbp_node_producer->op_node_) { continue; }
sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());
int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();
// look through sbp signature in producer
for (int32_t i = 0; i < sbp_node_producer->sbp_sig_list_.size(); ++i) {
sbp_edge->cost_[i].resize(consumer_sbp_size, 0);
}
}
// Find all those cases with wait time
// Do not skip edges carrying no lbi
sbp_node_consumer->InitializeCopyCost(use_sbp_collector_);
});
return Maybe<void>::Ok();
}
Maybe<void> SbpConstructor::ApplyTrunkAlgo() {
auto OpNode2MutableOpCtrlDeps = JUST(GetMutableOpCtrlDeps(*op_graph_));
// Compute layer number for each node
int32_t max_min_layer = sbp_graph_.ComputeLayer(op_name2sbp_node_, *OpNode2MutableOpCtrlDeps);
// Accumulate cost on the trunk after initializing computation cost
sbp_graph_.FindTrunk(max_min_layer, op_name2sbp_node_);
return Maybe<void>::Ok();
}
// Load logical blob ids onto sbp edges
void SbpConstructor::LoadLbi2SbpEdge(const OpGraph& op_graph) {
// Load logical blobs onto sbp edges
for (auto* sbp_node_consumer : sbp_graph_.node_list_) {
auto* op_node = sbp_node_consumer->op_node_;
// Loading logical blobs between two nodes
// look through input blobs
for (const std::string& ibn : op_node->op().input_bns()) {
// Each input blob has one source op node.
OpNode* producer = op_node->MutSrcNode4Ibn(ibn);
// producer sbp node
const auto* sbp_node_producer = op_name2sbp_node_[producer->op().op_name()];
// TODO: recode this
auto* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);
CHECK(edge_found != NULL) << "SbpEdge not found while loading!" << std::endl;
// Add copy cost for each blob
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
edge_found->LoadLbi(lbi);
}
};
}
Maybe<void> SbpConstructor::CheckSbpAgreement(const Job& job) {
Job new_job;
new_job.CopyFrom(job);
OpGraph op_graph(new_job);
// Compare sbp in job
JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {
const std::string& op_name = op_node->op().op_name();
const NdSbpSignature& auto_parallel_sbp =
NdSbpSignature(job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name));
const NdSbpSignature& new_sbp = op_node->nd_sbp_signature();
CHECK_EQ_OR_RETURN(auto_parallel_sbp.bn_in_op2nd_sbp_size(), new_sbp.bn_in_op2nd_sbp_size());
for (const auto& iter : auto_parallel_sbp.bn_in_op2nd_sbp()) {
const NdSbp& new_sbp_parallel = new_sbp.bn_in_op2nd_sbp().at(iter.first);
const NdSbp& auto_parallel_sbp = iter.second;
      // According to the error message, we can find the op_type in op_conf.proto with the type_id
      // and locate the erroneous op type.
      const std::string& error_msg =
          "Op: `" + op_name + "`(type_id: " + std::to_string(op_node->op().op_conf().op_type_case())
          + ") changed sbp from " + NdSbpToString(auto_parallel_sbp) + "(AutoParallel) to "
          + NdSbpToString(new_sbp_parallel) + "(OpGraph) with blob_name: `" + iter.first + "`.";
      CHECK_OR_RETURN(new_sbp_parallel == auto_parallel_sbp) << error_msg;
}
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
Maybe<HashMap<const OpNode*, HashSet<std::string>>> SbpConstructor::GetMutableOpCtrlDeps(
const OpGraph& op_graph) {
auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
for (const std::string& bn : op.input_bns()) {
if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }
}
return false;
};
const auto& IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
HashMap<const OpNode*, HashSet<std::string>> op_node2ctrl_in_op_names;
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }
if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }
const Operator& variable_op = op_node->op();
const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());
const OpNode* mutable_consumer = nullptr;
std::vector<const OperatorConf*> naive_consumers;
naive_consumers.reserve(op_node->out_edges().size());
for (OpEdge* edge : op_node->out_edges()) {
const auto& op_conf = edge->dst_node()->op().op_conf();
if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {
CHECK_OR_RETURN(mutable_consumer == nullptr);
mutable_consumer = edge->dst_node();
} else {
naive_consumers.emplace_back(&op_conf);
}
}
if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }
for (const auto* fw_bw_op : naive_consumers) {
op_node2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());
}
return Maybe<void>::Ok();
}));
// Filter ctrl edges if all ctrl_in_op_names are reachable
HashMap<const OpNode*, HashSet<std::string>> filter_op_ctrl_deps;
for (const auto& pair : op_node2ctrl_in_op_names) {
const OpNode* op_node = pair.first;
for (const auto& fw_bw_op_name : pair.second) {
if (!IsReachable(fw_bw_op_name, op_node->op().op_name())) {
filter_op_ctrl_deps[op_node].insert(fw_bw_op_name);
}
}
}
return filter_op_ctrl_deps;
}
// Print the graph with SBP in order
void SbpConstructor::PrintSBPGraphDebugInfo() {
// sbp constructor information
std::cout << "cost_ratio_:" << cost_ratio_ << std::endl;
std::cout << "wait_time_:" << sbp_graph_.wait_time_ << std::endl;
std::cout << "use_sbp_collector_" << use_sbp_collector_ << std::endl;
// test debug
std::cout << "Get Into Print Op Graph" << std::endl;
// Collect op_node
std::vector<OpNode*> node_list;
for (const auto& op_name_sbp_node : op_name2sbp_node_) {
auto* op_node_ = op_name_sbp_node.second->op_node_;
if (op_node_) { node_list.push_back(op_node_); }
}
// test debug
std::cout << "Deciding order" << std::endl;
// Decide the order to visit the op
std::vector<int32_t> order;
auto_parallel::DecideOrder(node_list, order, [&](OpNode* a, OpNode* b) {
return a->op().op_name().compare(b->op().op_name()) > 0;
});
std::vector<int32_t> str_order;
// test debug
std::cout << "Finish deciding order" << std::endl;
for (int32_t i = 0; i < node_list.size(); i++) {
OpNode* op_node = node_list[order[i]];
std::cout << op_node->op().op_name() << " (^_^):" << std::endl;
// get corresponding sbp node
const auto& it = op_name2sbp_node_.find(op_node->op().op_name());
// Print debug information for sbp graph
CHECK(it != op_name2sbp_node_.end());
const SbpNode* sbp_node = it->second;
std::cout << "Computation Cost: " << sbp_node->cost_[sbp_node->final_sbp_sig_id_];
std::cout << ", Min Layer: " << sbp_node->min_layer_ << ", Max Layer: " << sbp_node->max_layer_
<< ", Tributary Layer: " << sbp_node->tributary_layer_
<< ", in trunk: " << sbp_node->on_trunk_
<< ", Remain Cost: " << sbp_node->acc_trunk_cost_ << std::endl;
// Sort before printing
const auto& op_input_bns = op_node->op().input_bns();
auto CompareString = [](const std::string& a, const std::string& b) {
return a.compare(b) > 0;
};
auto_parallel::DecideOrder(op_input_bns, str_order, CompareString);
const NdSbpSignature& sbp_signature = sbp_node->FinalSbpSignature();
// Print out SBP information for input operator
for (int32_t j : str_order) {
const auto& ibn = op_input_bns[j];
const auto& producer_node = op_node->SrcNode4Ibn(ibn);
std::cout << "Pre Op:" << producer_node.op().op_name() << ": " << ibn;
const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(ibn);
std::cout << ", " << NdSbpToString(this_sbp_parallel);
if (RequireSameSbp(op_node, ibn)) { std::cout << ", require same SBP"; }
std::cout << ", "
<< op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape().elem_cnt();
std::cout << std::endl;
}
// Sort before printing
const auto& op_output_bns = op_node->op().output_bns();
auto_parallel::DecideOrder(op_output_bns, str_order, CompareString);
// Print out SBP information for output blobs
for (int32_t j : str_order) {
const auto& obn = op_output_bns[j];
std::cout << "Out Op:" << obn;
const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(obn);
std::cout << ", " << NdSbpToString(this_sbp_parallel);
std::cout << ", "
<< op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape().elem_cnt();
std::cout << std::endl;
}
std::cout << std::endl;
}
}
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
class OpGraph;
class Job;
namespace auto_parallel {
// A constructor which will assemble the sbp_graph with the information from oneflow.
// SbpGraph contains the algorithms for elimination and search, which are mainly for the strategy
// itself. The constructor mainly deals with the assembly of each node and edge, the cost
// computation, and the activation of functions.
class SbpConstructor final {
public:
OF_DISALLOW_COPY_AND_MOVE(SbpConstructor);
SbpConstructor() = delete;
SbpConstructor(const OpGraph& op_graph, Job* job)
: cost_ratio_(job->job_conf().auto_parallel_computation_cost_ratio()),
enable_trunk_algo_(job->job_conf().enable_auto_parallel_trunk_algo()),
use_sbp_collector_(!Singleton<ResourceDesc, ForSession>::Get()
->resource()
.disable_group_boxing_by_dst_parallel()
&& job->job_conf().enable_auto_parallel_sbp_collector()),
op_graph_(&op_graph) {
sbp_graph_.SetWaitTime(job->job_conf().auto_parallel_wait_time());
CHECK_JUST(Init(op_graph, job));
}
~SbpConstructor() = default;
Maybe<void> Init(const OpGraph& op_graph, Job* job);
Maybe<void> FindBestSbpSignature();
Maybe<void> DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job);
// Re-build OpGraph and check all sbp is same between op_graph and job
Maybe<void> CheckSbpAgreement(const Job& job);
// Print the graph with SBP in order
void PrintSBPGraphDebugInfo();
private:
Maybe<void> InitSbpGraph(const OpGraph& op_graph, const Job& job);
Maybe<void> GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job);
Maybe<void> FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job);
Maybe<void> StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job);
Maybe<void> InitComputationCost(const OpGraph& op_graph);
Maybe<void> InitCopyCost(const OpGraph& op_graph);
Maybe<void> ApplyTrunkAlgo();
Maybe<HashMap<const OpNode*, HashSet<std::string>>> GetMutableOpCtrlDeps(const OpGraph& op_graph);
// Load logical blob ids onto sbp edges
void LoadLbi2SbpEdge(const OpGraph& op_graph);
double cost_ratio_;
bool enable_trunk_algo_;
bool use_sbp_collector_;
SbpGraph sbp_graph_;
const OpGraph* op_graph_;
HashMap<std::string, SbpNode*> op_name2sbp_node_;
};
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <assert.h>
#include <algorithm>
#include <unordered_set>
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
namespace auto_parallel {
// Functions in this cpp file should be kept in one file due to the use of templates.
// Otherwise we would need to declare the specific templates at the end of the cpp file.
SbpEdge::SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,
SbpEdge* second_edge)
: start_node_(start_node), mid_node_(mid_node), end_node_(end_node) {
edge_list_.emplace_back(first_edge);
edge_list_.emplace_back(second_edge);
};
// Destructor
SbpEdge::~SbpEdge() {
if (mid_node_ != nullptr) { delete mid_node_; }
for (auto& this_edge : edge_list_) { delete this_edge; }
}
void SbpEdge::SummarizeCost() {
if (mid_node_) {
cost_.resize(start_node_->cost_.size());
mid_node_sbp_sig_.resize(start_node_->cost_.size());
int32_t end_node_sbp_size = end_node_->cost_.size();
int32_t mid_node_sbp_size = mid_node_->cost_.size();
for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
cost_[sbp_start].resize(end_node_sbp_size);
mid_node_sbp_sig_[sbp_start].resize(end_node_sbp_size);
for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {
for (int32_t sbp_mid = 0; sbp_mid < mid_node_sbp_size; sbp_mid++) {
// Add middle node cost
double temp_cost = mid_node_->cost_[sbp_mid];
// Add first edge cost
if (edge_list_[0]->start_node_ == start_node_) {
temp_cost += edge_list_[0]->cost_[sbp_start][sbp_mid];
} else {
temp_cost += edge_list_[0]->cost_[sbp_mid][sbp_start];
}
// Add second edge cost
if (edge_list_[1]->end_node_ == end_node_) {
temp_cost += edge_list_[1]->cost_[sbp_mid][sbp_end];
} else {
temp_cost += edge_list_[1]->cost_[sbp_end][sbp_mid];
}
// Compare and look for the minimum cost
if (sbp_mid == 0) {
cost_[sbp_start][sbp_end] = temp_cost;
mid_node_sbp_sig_[sbp_start][sbp_end] = sbp_mid;
} else if (temp_cost < cost_[sbp_start][sbp_end]) {
cost_[sbp_start][sbp_end] = temp_cost;
mid_node_sbp_sig_[sbp_start][sbp_end] = sbp_mid;
}
}
}
}
} else {
cost_.resize(start_node_->cost_.size());
int32_t end_node_sbp_size = end_node_->cost_.size();
for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
cost_[sbp_start].resize(end_node_sbp_size);
for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {
cost_[sbp_start][sbp_end] = 0;
for (int32_t edge_num = 0; edge_num < edge_list_.size(); edge_num++) {
if (edge_list_[edge_num]->start_node_ == start_node_) {
cost_[sbp_start][sbp_end] += edge_list_[edge_num]->cost_[sbp_start][sbp_end];
} else {
cost_[sbp_start][sbp_end] += edge_list_[edge_num]->cost_[sbp_end][sbp_start];
}
}
}
}
}
}
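// In effect, the middle-node branch above computes, for every (sbp_start, sbp_end) pair,
//   cost_[i][j] = min over k of (mid_node_->cost_[k] + edge0_cost(i, k) + edge1_cost(k, j))
// and remembers the minimizing k in mid_node_sbp_sig_, while the plain branch just sums the
// costs of the merged parallel edges; the index swaps handle edges stored in the opposite
// direction.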
void SbpEdge::DuplicateCost(
bool merged_node_is_start_node, bool duplicating_first_node,
const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2children_sig_id) {
const int32_t num_sig = merged_sig_id2children_sig_id.size();
std::vector<std::vector<double>> temp_cost;
std::vector<std::vector<int32_t>> temp_mid_node_sbp_sig;
if (merged_node_is_start_node) {
temp_cost.resize(num_sig);
if (mid_node_) { temp_mid_node_sbp_sig.resize(num_sig); }
for (int32_t i = 0; i < num_sig; i++) {
const int32_t sig_idx = duplicating_first_node ? merged_sig_id2children_sig_id[i].first
: merged_sig_id2children_sig_id[i].second;
temp_cost[i] = cost_[sig_idx];
if (mid_node_) { temp_mid_node_sbp_sig[i] = mid_node_sbp_sig_[sig_idx]; }
}
} else {
const int32_t num_start_sig = cost_.size();
temp_cost.resize(num_start_sig);
if (mid_node_) { temp_mid_node_sbp_sig.resize(num_start_sig); }
for (int32_t i = 0; i < num_start_sig; i++) {
temp_cost[i].resize(num_sig);
if (mid_node_) { temp_mid_node_sbp_sig[i].resize(num_sig); }
for (int32_t j = 0; j < num_sig; j++) {
const int32_t sig_idx = duplicating_first_node ? merged_sig_id2children_sig_id[j].first
: merged_sig_id2children_sig_id[j].second;
temp_cost[i][j] = cost_[i][sig_idx];
if (mid_node_) { temp_mid_node_sbp_sig[i][j] = mid_node_sbp_sig_[i][sig_idx]; }
}
}
}
cost_ = temp_cost;
if (mid_node_) { mid_node_sbp_sig_ = temp_mid_node_sbp_sig; }
}
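// Illustrative example (made-up numbers): with merged_sig_id2children_sig_id = {(0,0), (1,0),
// (1,1)}, duplicating_first_node = true and the merged node as the start node, a 2-row cost_
// is expanded to 3 rows taken from rows {0, 1, 1} of the original, so each merged signature
// reuses the cost of its first child's signature.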
void SbpEdge::FinalizeSbp() {
// Finalize Sbp for mid_node_
if (mid_node_) {
mid_node_->final_sbp_sig_id_ =
mid_node_sbp_sig_[start_node_->final_sbp_sig_id_][end_node_->final_sbp_sig_id_];
mid_node_->FinalizeSbp();
}
for (const auto& this_edge : edge_list_) { this_edge->FinalizeSbp(); }
}
double SbpEdge::GreedyStrategy() {
// Sbp combination of the minimum cost
int32_t min_sbp_start = start_node_->final_sbp_sig_id_,
min_sbp_end = end_node_->final_sbp_sig_id_;
// An unordered_map to evaluate cost between two edge nodes and other nodes.
std::unordered_map<int32_t, int32_t> node_list_id2nbh_id = {{start_node_->node_list_id_, 0},
{end_node_->node_list_id_, 1}};
// pre-compute and store the current cost between end_node_ and outside.
std::vector<double> end_node_out_cost(end_node_->cost_.size());
for (int32_t sbp_end = 0; sbp_end < cost_[0].size(); sbp_end++) {
end_node_->final_sbp_sig_id_ = sbp_end;
end_node_out_cost[sbp_end] = end_node_->EvalOutNbhCost(node_list_id2nbh_id);
}
// pre-compute and store the current cost between start_node_ and outside.
std::vector<double> start_node_out_cost(start_node_->cost_.size());
for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
start_node_->final_sbp_sig_id_ = sbp_start;
start_node_out_cost[sbp_start] = start_node_->EvalOutNbhCost(node_list_id2nbh_id);
}
// Current Cost, Minimum Cost, Cost with original sbp
double curr_cost = 0.0;
double min_cost = start_node_out_cost[min_sbp_start] + end_node_out_cost[min_sbp_end]
+ cost_[min_sbp_start][min_sbp_end];
double original_cost = min_cost;
for (int32_t sbp_start = 0; sbp_start < cost_.size(); sbp_start++) {
for (int32_t sbp_end = 0; sbp_end < cost_[0].size(); sbp_end++) {
// compute Current Cost for Neighborhood of edge
end_node_->final_sbp_sig_id_ = sbp_end;
curr_cost =
start_node_out_cost[sbp_start] + end_node_out_cost[sbp_end] + cost_[sbp_start][sbp_end];
// Find the minimum current cost
if (curr_cost < min_cost) {
min_cost = curr_cost;
min_sbp_start = sbp_start;
min_sbp_end = sbp_end;
}
}
}
start_node_->final_sbp_sig_id_ = min_sbp_start;
end_node_->final_sbp_sig_id_ = min_sbp_end;
return min_cost - original_cost;
}
// Get the minimum element in Cost
double SbpEdge::GetMinCost() {
  // Use the stored value if pre-computed.
if (min_cost_ >= 0) { return min_cost_; }
// Check the size of Cost
CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl;
// Compute the min_cost
min_cost_ = *std::min_element(cost_[0].begin(), cost_[0].end());
for (int32_t i = 1; i < cost_.size(); i++) {
double min_cost_row = *std::min_element(cost_[i].begin(), cost_[i].end());
if (min_cost_row < min_cost_) { min_cost_ = min_cost_row; }
}
return min_cost_;
}
// Get the maximum element in Cost
double SbpEdge::GetMaxCost() const {
  // The result is not cached here: unlike min_cost_, the maximum cost is not stored as a member.
// Check the size of Cost
CHECK(cost_.size() > 0) << "Cost not initialized!" << std::endl;
// Compute the max_cost
double max_cost = -1.0;
for (int32_t i = 0; i < cost_.size(); i++) {
for (int32_t j = 0; j < cost_[i].size(); j++) {
if (cost_[i][j] < GetValidMaxCopyCost() && cost_[i][j] > max_cost) { max_cost = cost_[i][j]; }
}
}
return max_cost;
}
// Assemble copy cost
void SbpEdge::InitializeCopyCost(const std::string& ibn, bool use_sbp_collector) {
  // In this part, we assemble the copy cost between nodes.
if (start_node_->op_node_ && end_node_->op_node_) {
OpNode* consumer = end_node_->op_node_;
// Add copy cost for each blob
const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn);
// Check whether lbi is transferred by this edge
if (use_sbp_collector && !SearchLbi(lbi)) { return; }
OpNode* producer = start_node_->op_node_;
    // The producer's output blob name for this lbi; it is also used below to look up the sbp.
    const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));
    const ParallelDesc& producer_parallel_desc =
        *CHECK_JUST(producer->op().GetParallelDesc4BnInOp(obn));
    const ParallelDesc& consumer_parallel_desc =
        *CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn));
    // Be careful: the logical blob description should be independent of the current
    // SbpParallel. Use producer or op_node?
    const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);
    // If we are deciding whether we need the wait time, then make require_same_sbp true.
    // B->S causes cudaEventSynchronize in the current implementation.
bool require_same_sbp = RequireSameSbp(consumer, ibn);
int32_t consumer_sbp_size = end_node_->sbp_sig_list_.size();
LazyMode::Guard enable_lazy_mode(true);
// look through sbp signature in producer
for (int32_t sbp_id_producer = 0; sbp_id_producer < start_node_->sbp_sig_list_.size();
sbp_id_producer++) {
// get sbp parallel for a logical blob in producer
const auto& producer_sbp_bn_in_op2sbp_parallel =
start_node_->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();
const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);
// look through sbp signature in consumer
for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {
// get sbp parallel for a logical blob in consumer
const auto& consumer_sbp_bn_in_op2sbp_parallel =
end_node_->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();
const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);
// compute copy cost for a specific logical blob
double curr_edge_cost = CHECK_JUST(ComputeCopyCostWithMiddleNodes(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, require_same_sbp));
if (curr_edge_cost < GetValidMaxCopyCost()) {
cost_[sbp_id_producer][sbp_id_consumer] +=
CHECK_JUST(producer->op().GetOpTimeShape())->elem_cnt() * curr_edge_cost;
} else {
cost_[sbp_id_producer][sbp_id_consumer] = curr_edge_cost;
}
}
}
}
}
// Get the cut ratio: (#c < GetValidMaxCopyCost() in Cost) / (#c in Cost)
double SbpEdge::GetCutRatio() const {
int32_t num = 0;
for (int32_t i = 0; i < cost_.size(); i++) {
for (int32_t j = 0; j < cost_[i].size(); j++) {
if (cost_[i][j] < GetValidMaxCopyCost()) { num++; }
}
}
return double(num) / double(cost_.size() * cost_[0].size());
}
// Find the cut ratio:
// (#c < GetValidMaxCopyCost() in Cost) / (#c in Cost)
double SbpEdge::FindCutRatio(int32_t threshold) const {
double cut_ratio = GetCutRatio();
  // Lift the cut ratio toward 1 for some improper couples to avoid unlimited merging
double n = cost_.size();
double m = cost_[0].size();
double num = cut_ratio * n * m;
cut_ratio += 0.16 * (n + m) / double(threshold);
if (num <= n * 2 || num <= m * 2 || (num <= threshold && cut_ratio < 0.51)) {
return cut_ratio;
} else {
return 1.0;
}
}
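// Numeric sketch (illustrative only): for a 4 x 5 cost matrix with 6 entries below
// GetValidMaxCopyCost() and threshold = 100, the base ratio is 6 / 20 = 0.3, lifted to
// 0.3 + 0.16 * (4 + 5) / 100 = 0.3144; since num = 6 <= n * 2 = 8, the lifted ratio is
// returned instead of 1.0.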
// load a logical blob
void SbpEdge::LoadLbi(const LogicalBlobId& lbi) { carry_lbis_.insert(lbi); }
// check the existence of a logical blob
bool SbpEdge::SearchLbi(const LogicalBlobId& lbi) const {
return carry_lbis_.find(lbi) != carry_lbis_.end();
}
// unload a logical blob
void SbpEdge::UnloadLbi(const LogicalBlobId& lbi) {
if (carry_lbis_.erase(lbi) == 0) { std::cout << "Unload an empty lbi!" << std::endl; }
}
// Not carrying any blob
bool SbpEdge::EmptyLbi() const { return carry_lbis_.empty(); }
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
#include <assert.h>
#include <algorithm>
#include <unordered_set>
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
namespace auto_parallel {
// An edge structure to deal with the SBP strategy.
// Please see SbpGraph for the whole algorithm and introduction.
class SbpEdge final {
/* There are 3 types of edges:
* 1. start_node_ -> end_node_
* Nothing special
* 2. Multiple start_node_ -> end_node_
   * edge_list_ will store all the edges which go from start_node_ to end_node_
* 3. start_node_ -> mid_node_ -> end_node_
* It will pass by a middle node.
*/
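  /* An illustrative picture of the three types, with A = start_node_,
   * M = mid_node_ and B = end_node_:
   *   type 1:  A -----> B
   *   type 2:  A =====> B    (parallel edges folded into edge_list_)
   *   type 3:  A --> M --> B (a compound edge created by node elimination)
   */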
public:
// Constructor for type 1 & 2
SbpEdge(SbpNode* start_node, SbpNode* end_node) : start_node_(start_node), end_node_(end_node) {
mid_node_ = nullptr;
}
// Constructor for type 3
SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,
SbpEdge* second_edge);
  // Destructor
~SbpEdge();
OF_DISALLOW_COPY_AND_MOVE(SbpEdge);
bool operator==(const SbpEdge& other) { return this == &other; }
// Update copy cost for type 2 and 3
void SummarizeCost();
// Duplicate Cost. Designed for merging two nodes.
void DuplicateCost(bool merged_node_is_start_node, bool duplicating_first_node,
const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2children_sig_id);
// Determine Final SbpSignature for attachment of this edge
void FinalizeSbp();
  // Use the greedy strategy to pick the sbp signature with minimum cost for this
  // edge. You should have an initial strategy before running this, and the
  // graph should be fully eliminated.
double GreedyStrategy();
// load a logical blob
void LoadLbi(const LogicalBlobId& lbi);
// check the existence of a logical blob
bool SearchLbi(const LogicalBlobId& lbi) const;
// unload a logical blob
void UnloadLbi(const LogicalBlobId& lbi);
// Not carrying any blob
bool EmptyLbi() const;
// Get the minimum element in Cost
double GetMinCost();
// Get the maximum element in Cost
double GetMaxCost() const;
// Assemble copy cost
void InitializeCopyCost(const std::string& ibn, bool use_sbp_collector);
  // Find the cut ratio:
  // (#c < GetValidMaxCopyCost() in Cost) / (#c in Cost)
  // But we would lift the cut ratio to 1 to filter out some improper couples
double FindCutRatio(int32_t threshold) const;
// Get the cut ratio
double GetCutRatio() const;
private:
friend class SbpNode;
friend class SbpGraph;
friend class SbpCollector;
friend class SbpConstructor;
// The edge point from start_node_ to end_node_
// It will have a middle node if and only if type 3
SbpNode *start_node_, *mid_node_, *end_node_;
// Cost[sbp_i][sbp_j] is the total cost from start_node_ with sbp_i to end_node_
// with sbp_j
std::vector<std::vector<double>> cost_;
// SbpSignature for mid_node_ with corresponding Cost if type 3, empty otherwise
std::vector<std::vector<int32_t>> mid_node_sbp_sig_;
  // Contained edge list:
  //   empty if type 1,
  //   parallel edges if type 2,
  //   successive edges if type 3.
  // The edge list might have reversed directions:
  //   example 1: a type 3 edge_list_ contains two edges:
  //       mid_node_ -> start_node_, mid_node_ -> end_node_;
  //   example 2: a type 2 edge_list_ contains three edges:
  //       start_node_ -> end_node_, end_node_ -> start_node_, start_node_ -> end_node_;
std::vector<SbpEdge*> edge_list_;
  // Time spent waiting for other GPUs (cf. pthread_cond_wait).
double wait_time_ = -1.0;
// a set of ids of logical blobs carried/transferred on this sbp edge
std::unordered_set<LogicalBlobId> carry_lbis_;
  // The minimum and maximum cost are not changed by eliminations, which generate new edges,
  // nor by node merging, which only copies costs along the expanding dimensions.
  // Minimum cost in the 2D array cost_.
  // Initialized by the first call to GetMinCost().
  // Only used in the final graph.
double min_cost_ = -1.0;
};
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include <unordered_map>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
namespace oneflow {
namespace auto_parallel {
// Function definitions live in this cpp file. They should be put in one file due to the use of
// templates; otherwise we would need to declare the specific templates at the end of the cpp file.
namespace {
const int32_t kMinNodeInGraphForMerging = 4;
} // anonymous namespace
// Generate a node
SbpNode* SbpGraph::GenerateNode() {
SbpNode* this_node = new SbpNode();
node_list_.emplace_back(this_node);
this_node->node_list_id_ = node_list_.size() - 1;
return this_node;
}
void SbpGraph::RemoveFromNodeList(SbpNode* this_node) {
if (this_node->node_list_id_ < 0) { return; }
node_list_.back()->node_list_id_ = this_node->node_list_id_;
RemoveFrom<SbpNode*>(node_list_, this_node->node_list_id_);
this_node->node_list_id_ = -1;
}
SbpGraph::~SbpGraph() {
for (auto this_node : node_list_) { delete this_node; }
node_list_.clear();
}
void SbpGraph::RandomSbpSignature(bool use_sbp_collector) const {
for (const auto& this_node : node_list_) {
if (this_node->sbp_sig_list_.size() > 0) {
this_node->final_sbp_sig_id_ = rand() % this_node->sbp_sig_list_.size();
} else {
// It must be a proxy when this_node->sbp_sig_list_.size() == 0
this_node->final_sbp_sig_id_ = rand() % this_node->parallel_candidates_.size();
}
}
}
void SbpGraph::SetDefaultSbpSig() const {
  for (const auto& this_node : node_list_) { this_node->final_sbp_sig_id_ = 0; }
}
double SbpGraph::ComputeCost() const {
  // Overall cost under the current strategy
  double graph_cost = 0;
  for (const auto& this_node : node_list_) {
    int32_t this_id = this_node->final_sbp_sig_id_;
    graph_cost += this_node->cost_[this_id];
    for (const auto& edge_out : this_node->edges_out_) {
      graph_cost += edge_out->cost_[this_id][edge_out->end_node_->final_sbp_sig_id_];
    }
  }
  return graph_cost;
}
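// Illustrative example (made-up numbers): for two nodes A -> B with node costs 1 and 2 and an
// edge cost of 3 under the chosen signatures, ComputeCost() returns 1 + 2 + 3 = 6; each
// out-edge is counted exactly once.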
int32_t SbpGraph::NodeElimination(SbpNode* this_node) {
if (this_node->edges_in_.size() + this_node->edges_out_.size() == 2) {
std::vector<SbpNode*> two_nodes;
for (const auto& one_edge : this_node->edges_in_) two_nodes.emplace_back(one_edge->start_node_);
for (const auto& one_edge : this_node->edges_out_) two_nodes.emplace_back(one_edge->end_node_);
    // If a node is pointing to itself, which could happen when shrinking from a circle
if (two_nodes[0] == two_nodes[1]) {
int32_t elimination_number = 0;
if (this_node->edges_out_.empty()) {
elimination_number += EdgeElimination(two_nodes[0]);
} else {
elimination_number += EdgeElimination(this_node);
}
elimination_number += ChildElimination(this_node);
return elimination_number;
}
std::vector<SbpEdge*> two_edges(this_node->edges_in_);
two_edges.insert(two_edges.end(), this_node->edges_out_.begin(), this_node->edges_out_.end());
int32_t edges_in_size = this_node->edges_in_.size();
SbpEdge* e = new SbpEdge(two_nodes[0], this_node, two_nodes[1], two_edges[0], two_edges[1]);
e->SummarizeCost();
    // Remove the absorbed in-edges from their start nodes' edges_out_
for (int32_t i = 0; i < edges_in_size; i++) {
CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_out_, two_edges[i]);
}
    // Remove the absorbed out-edges from their end nodes' edges_in_
for (int32_t i = edges_in_size; i < 2; i++) {
CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_in_, two_edges[i]);
}
    // Let e take control of edge_list_ completely by disconnecting mid_node_
e->mid_node_->edges_out_.clear();
e->mid_node_->edges_in_.clear();
// Insert new compound edge into graph
two_nodes[0]->edges_out_.emplace_back(e);
two_nodes[1]->edges_in_.emplace_back(e);
// eliminate the node from graph by swapping with the last element and
// popping
RemoveFromNodeList(this_node);
// successfully eliminate this node
return 1;
}
// can not eliminate this node
return 0;
}
int32_t SbpGraph::NodeAndEdgeEliminations() {
// Total elimination number
int32_t total_elimination_num = 0;
int32_t elimination_num = 1;
// repeat these kinds of elimination until stuck
while (elimination_num > 0) {
elimination_num = 0;
for (int32_t i = node_list_.size() - 1; i >= 0; i--) {
elimination_num += NodeElimination(node_list_[i]);
}
for (int32_t i = node_list_.size() - 1; i >= 0; i--) {
elimination_num += EdgeElimination(node_list_[i]);
}
for (int32_t i = node_list_.size() - 1; i >= 0; i--) {
elimination_num += ChildElimination(node_list_[i]);
}
if (elimination_num == 0 && node_list_.size() > 2) {
elimination_num += PickAndMerge();
for (int32_t i = node_list_.size() - 1; i >= 0; i--) {
elimination_num += EdgeElimination(node_list_[i]);
}
}
total_elimination_num += elimination_num;
}
return total_elimination_num;
}
int32_t SbpGraph::EdgeElimination(SbpNode* this_node) const {
// Remove all edges with (start_node -> end_node) from edges_in_ of end_node
auto RemoveFromEdgesIn = [](SbpNode* start_node, SbpNode* end_node) -> void {
for (int32_t i = end_node->edges_in_.size() - 1; i >= 0; i--) {
if (start_node == end_node->edges_in_[i]->start_node_) {
RemoveFrom<SbpEdge*>(end_node->edges_in_, i);
}
}
};
auto LookForParallelEdge = [](SbpEdge*& e, SbpNode* start_node, SbpNode* end_node,
bool if_reverse, int32_t stop_sign) -> int32_t {
    // Eliminate edges with the given start node and end node in
    // start_node->edges_out_, from index stop_sign to the end.
    // start_node->edges_out_[stop_sign] is not included and needs special
    // treatment after this process.
int32_t elimination_num = 0;
for (int32_t j = start_node->edges_out_.size() - 1; j > stop_sign; j--) {
if (end_node == start_node->edges_out_[j]->end_node_) {
if (!e) {
if (if_reverse) {
e = new SbpEdge(end_node, start_node);
} else {
e = new SbpEdge(start_node, end_node);
}
}
// edge elimination
e->edge_list_.emplace_back(start_node->edges_out_[j]);
elimination_num++;
RemoveFrom<SbpEdge*>(start_node->edges_out_, j);
}
}
return elimination_num;
};
int32_t elimination_num = 0;
for (int32_t i = 0; i < this_node->edges_out_.size(); i++) {
SbpEdge* e = nullptr;
// Find and delete Parallel Edges from edges_out_
elimination_num += LookForParallelEdge(e, this_node, this_node->edges_out_[i]->end_node_,
/*if_reverse=*/false, i);
elimination_num += LookForParallelEdge(e, this_node->edges_out_[i]->end_node_, this_node,
/*if_reverse=*/true, /*stop_sign=*/-1);
if (e) {
// Delete Parallel Edges from edges_in_
RemoveFromEdgesIn(this_node, e->end_node_);
RemoveFromEdgesIn(e->end_node_, this_node);
// Add the compound edge
e->edge_list_.emplace_back(this_node->edges_out_[i]);
this_node->edges_out_[i] = e;
e->SummarizeCost();
e->end_node_->edges_in_.emplace_back(e);
}
}
return elimination_num;
}
int32_t SbpGraph::ChildElimination(SbpNode* this_node) {
if (this_node->EliminateItselfAsChild()) {
// eliminate this node from global node list
RemoveFromNodeList(this_node);
// successfully eliminate this node
return 1;
} else {
// can not eliminate this node
return 0;
}
}
// Merge two nodes
int32_t SbpGraph::NodeMerging(SbpNode* first, SbpNode* second) {
SbpNode* new_node = new SbpNode(first, second);
// Adjust node_list_
RemoveFromNodeList(first);
RemoveFromNodeList(second);
new_node->node_list_id_ = node_list_.size();
node_list_.emplace_back(new_node);
return 1;
}
void SbpGraph::FinalizeSbp() const {
for (const auto& this_node : node_list_) { this_node->FinalizeSbp(); }
}
double SbpGraph::GreedyStrategy(bool for_node) const {
// Overall, this function should be replaced by GreedyStrategy(nbh_num);
  // Total cost reduction & cost reduction for one loop
double total_cost_reduction = 0, cost_reduction = 0;
for (int32_t step = node_list_.size(); step >= 0; step--) {
cost_reduction = 0;
for (SbpNode* this_node : node_list_) {
      // Use GreedyStrategy on nodes if there is one node left for this
      // connected component. Otherwise, use GreedyStrategy on edges.
if (for_node || this_node->edges_in_.size() + this_node->edges_out_.size() == 0) {
cost_reduction += this_node->GreedyStrategy();
} else {
// GreedyStrategy on Edges.
        for (SbpEdge* this_edge : this_node->edges_out_) {
          cost_reduction += this_edge->GreedyStrategy();
        }
}
}
if (cost_reduction == 0) { break; }
total_cost_reduction += cost_reduction;
}
return total_cost_reduction;
}
double SbpGraph::GreedyStrategy(int32_t nbh_num) const {
  // nbh_num is the maximum number of neighborhood nodes whose sbp strategies are adjusted per step
  // Total cost reduction & cost reduction for one loop
double total_cost_reduction = 0, cost_reduction = 0;
// A global buffer to store part of the one ring neighborhood.
std::vector<int32_t> nbh_id2node_list_id;
  // Do not accept a number lower than 1
if (nbh_num < 1) { nbh_num = 1; }
nbh_id2node_list_id.resize(nbh_num);
std::vector<int32_t> original_sbp_sig_id(nbh_num);
  // Store all the node_list_ids whose corresponding nodes will be visited.
  // We could use an unordered_map for this, but a vector is faster.
std::vector<int32_t> pre_visit_node_list(node_list_.size() + 1);
for (int32_t nbh_id = 0; nbh_id < node_list_.size(); nbh_id++) {
pre_visit_node_list[nbh_id] = nbh_id;
}
int32_t head = 0, tail = node_list_.size();
// whether a node_list_id is in pre_visit_node_list
std::vector<bool> pre_visit_tags(node_list_.size(), true);
int32_t step = 0;
// 1 ring neighborhood buffer
std::vector<int32_t> nbh_1ring(nbh_num);
// 2 ring neighborhood buffer
std::vector<int32_t> nbh_2ring;
std::vector<bool> node_tags(node_list_.size(), false);
std::vector<int32_t> nbh_1ring_buffer;
while (head != tail && step < node_list_.size()) {
auto* this_node = node_list_[pre_visit_node_list[head]];
if (nbh_num <= 1) {
      // Greedy strategy on nodes; here we use nbh_1ring to store the nbh_id2node_list_id
      // information for reuse
nbh_1ring[0] = this_node->node_list_id_;
// store the original sbp signature of the 1-ring neighborhood for comparison
original_sbp_sig_id[0] = this_node->final_sbp_sig_id_;
cost_reduction = NbhGreedyStrategy(nbh_1ring);
} else {
// Use GreedyStrategy on the one ring neighborhood of this node.
this_node->OneRingNeighborhood(nbh_1ring);
// store the original sbp signature of the 1-ring neighborhood for comparison
original_sbp_sig_id.resize(nbh_1ring.size());
for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {
original_sbp_sig_id[nbh_id] = node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_;
}
if (nbh_1ring.size() <= nbh_num) {
cost_reduction = NbhGreedyStrategy(nbh_1ring);
} else {
// Use GreedyStrategy on part of the one ring neighborhood.
// Loop through the neighborhood. Each loop should contain the centroid.
// Initialize part of the one ring neighborhood
int32_t nbh_1ring_id = nbh_1ring.size() - nbh_num;
for (int32_t nbh_id = 1; nbh_id < nbh_num; ++nbh_id) {
nbh_id2node_list_id[nbh_id] = nbh_1ring[++nbh_1ring_id];
}
// loop through the one ring neighborhood
cost_reduction = 0;
int32_t nbh_id = 0;
for (nbh_1ring_id = 0; nbh_1ring_id < nbh_1ring.size(); ++nbh_1ring_id) {
nbh_id2node_list_id[nbh_id] = nbh_1ring[nbh_1ring_id];
cost_reduction += NbhGreedyStrategy(nbh_id2node_list_id);
// nbh_id for the next step
if (++nbh_id >= nbh_num) { nbh_id = 1; }
}
}
}
// change of strategies
if (cost_reduction != 0) {
// Add neighborhood into pre-visited node list for each node with changing strategy
for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {
// If changes occur
if (original_sbp_sig_id[nbh_id] != node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_) {
// schedule to visit the neighborhood of that changing node
node_list_[nbh_1ring[nbh_id]]->NRingNeighborhood(2, nbh_2ring, nbh_1ring_buffer,
node_list_, node_tags);
for (int32_t nbh_node_list_id : nbh_2ring) {
// Put them into the pre-visited node list
if (!pre_visit_tags[nbh_node_list_id]) {
pre_visit_node_list[tail] = nbh_node_list_id;
pre_visit_tags[nbh_node_list_id] = true;
tail++;
if (tail == pre_visit_node_list.size()) { tail = 0; }
}
}
}
}
}
// Finish visiting
pre_visit_tags[pre_visit_node_list[head]] = false;
head++;
if (head == pre_visit_node_list.size()) {
head = 0;
step++;
}
total_cost_reduction += cost_reduction;
}
return total_cost_reduction;
}
void SbpGraph::DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,
std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
std::vector<int32_t>& order2nbh_id, std::vector<int32_t>& nbh_id2order,
std::vector<double>& order2acc_min_in_nbh_cost,
std::vector<std::vector<double>>& out_nbh_costs,
std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,
std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,
double curr_cost) const {
// We have finished visiting the neighborhood
if (order >= nbh_id2node_list_id.size()) {
// relative difference > 1e-12
if (curr_cost < min_cost * kFloatDeviationMinus) {
min_cost = curr_cost;
for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {
min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;
}
}
return;
}
  // Pruning: remove all the branches with a large cost
if (curr_cost + order2acc_min_in_nbh_cost[order] >= min_cost) { return; }
  // Depth-first search in the next order
int32_t nbh_id = order2nbh_id[order];
SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
for (int32_t sbp_id : nbh_id2order2sbp_id[nbh_id]) {
sbp_node->final_sbp_sig_id_ = sbp_id;
DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,
order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,
min_cost, order + 1,
curr_cost + out_nbh_costs[nbh_id][sbp_id]
+ sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order));
}
}
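// Pruning sketch (illustrative numbers): if min_cost = 10 so far, curr_cost = 6 and
// order2acc_min_in_nbh_cost[order] = 5, then 6 + 5 >= 10 and the whole subtree rooted at this
// order is skipped without enumerating its signatures.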
bool SbpGraph::DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,
std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const {
// We found such a strategy
if (nbh_id == nbh_id2order.size()) { return true; }
SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
  // Start from B (i.e., iterate from the last sbp signature down).
for (int32_t sbp_id = sbp_node->cost_.size() - 1; sbp_id >= 0; sbp_id--) {
sbp_node->final_sbp_sig_id_ = sbp_id;
// If the cost for this node is reasonable, then go to the next one
if (sbp_node->cost_[sbp_id] + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order)
< GetValidMaxCopyCost()) {
if (DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order,
nbh_id + 1)) {
        // If we found one strategy, then exit the DFS.
return true;
}
}
}
  // Cannot find a reasonable strategy with the current setting for the previous nodes.
  // Go back and change the previous node.
return false;
}
// Find one strategy with finite cost for adjustment
Maybe<void> SbpGraph::Find1Strategy4Greedy() const {
std::vector<int32_t> nbh_id2node_list_id;
std::vector<bool> not_visited(node_list_.size(), true);
std::vector<int32_t> nbh_1ring;
int32_t head = 0;
int32_t tail = 0;
std::vector<double> node_cut_ratios(node_list_.size());
// Initialize cut ratio for all the nodes
for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {
node_cut_ratios[node_list_id] = node_list_[node_list_id]->GetCutRatio();
}
  // While we have not visited all the nodes
while (tail < node_list_.size()) {
// Find the node with the minimum cut ratio
int32_t node_with_min_cut_ratio = -1;
double min_cut_ratio = 2.0;
for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {
if (not_visited[node_list_id]) {
double curr_cut_ratio = node_cut_ratios[node_list_id];
if (curr_cut_ratio < min_cut_ratio) {
min_cut_ratio = curr_cut_ratio;
node_with_min_cut_ratio = node_list_id;
}
}
}
// put this node into the open set
nbh_id2node_list_id.push_back(node_with_min_cut_ratio);
not_visited[node_with_min_cut_ratio] = false;
tail++;
// BFS
while (head < tail) {
// look for the neighborhood of the head
int32_t node_list_id = nbh_id2node_list_id[head];
node_list_[node_list_id]->OneRingNeighborhood(nbh_1ring);
// sort
std::sort(nbh_1ring.begin(), nbh_1ring.end(),
[&](int32_t i, int32_t j) { return node_cut_ratios[i] < node_cut_ratios[j]; });
for (int32_t curr_id : nbh_1ring) {
if (not_visited[curr_id]) {
nbh_id2node_list_id.push_back(curr_id);
tail++;
not_visited[curr_id] = false;
}
}
head++;
}
}
// mapping from the node_list_id to the id in the nbh_id2node_list_id
std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;
InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);
  // Initialize an ordinary order
std::vector<int32_t> nbh_id2order(nbh_id2node_list_id.size());
for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {
nbh_id2order[nbh_id] = nbh_id;
}
  // Combine depth-first search and pruning based on the cut ratio
CHECK(DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order, /*nbh_id=*/0))
<< "Can't find a reasonable strategy!";
return Maybe<void>::Ok();
}
// Use brute force to search for a strategy with minimum cost for a neighborhood
double SbpGraph::NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const {
// number of nodes in the neighborhood
int32_t num_nbh = nbh_id2node_list_id.size();
// mapping from the node_list_id to the id in the nbh_id2node_list_id
std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;
InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);
  // An sbp signature id set minimizing the overall cost; store the original one as the default
std::vector<int32_t> min_sbp_sig_id(num_nbh);
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;
}
// pre-compute and store the cost between neighborhood and outside nodes under different sbp for
// each node within the neighborhood
std::vector<std::vector<double>> out_nbh_costs(num_nbh);
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
out_nbh_costs[nbh_id].resize(sbp_node->cost_.size());
for (int32_t sbp_id = sbp_node->cost_.size() - 1; sbp_id >= 0; sbp_id--) {
sbp_node->final_sbp_sig_id_ = sbp_id;
out_nbh_costs[nbh_id][sbp_id] = sbp_node->EvalOutNbhCost(node_list_id2nbh_id);
}
}
// pre-compute and store the order of the out_nbh_costs
std::vector<std::vector<int32_t>> nbh_id2order2sbp_id(num_nbh);
auto CompareDoubleLess = [](double a, double b) { return a < b; };
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
DecideOrder(out_nbh_costs[nbh_id], nbh_id2order2sbp_id[nbh_id], CompareDoubleLess);
}
// Decide the order to go through the neighborhood.
// Should visit those nodes with a larger difference in the out cost first.
std::vector<double> out_nbh_cost_diff(num_nbh);
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
out_nbh_cost_diff[nbh_id] =
*std::max_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())
- *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end());
}
std::vector<int32_t> order2nbh_id;
DecideOrder(out_nbh_cost_diff, order2nbh_id, [](double a, double b) { return a > b; });
// Find the inverse map of order
std::vector<int32_t> nbh_id2order;
InverseOrder(order2nbh_id, nbh_id2order);
// Current Cost, Minimum Cost, Cost with original sbp
double original_cost = 0;
// Recover original sbp
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];
}
// Compute cost with original sbp
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];
original_cost += out_nbh_costs[nbh_id][min_sbp_sig_id[nbh_id]];
original_cost += sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order);
}
double min_cost = original_cost;
  // Accumulate the minimum cost from the current node to the end of the neighborhood node list.
  // The accumulated cost includes the current node.
std::vector<double> order2acc_min_in_nbh_cost(num_nbh);
order2acc_min_in_nbh_cost[num_nbh - 1] =
*std::min_element(out_nbh_costs[order2nbh_id[num_nbh - 1]].begin(),
out_nbh_costs[order2nbh_id[num_nbh - 1]].end());
for (int32_t order = num_nbh - 2; order >= 0; order--) {
int32_t nbh_id = order2nbh_id[order];
order2acc_min_in_nbh_cost[order] =
order2acc_min_in_nbh_cost[order + 1]
+ *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())
+ node_list_[nbh_id2node_list_id[nbh_id]]->EvalMinInNbhCost(node_list_id2nbh_id,
nbh_id2order);
}
// Use brute force (DFS) to adjust for the best strategy in the neighborhood.
DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,
order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,
min_cost, /*order=*/0, /*curr_cost=*/0);
// Use the sbp strategy with minimum cost
for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {
node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];
}
if (min_cost < original_cost) {
    // Directly returning (min_cost - original_cost) might carry a floating point error up to 3e-16.
    // For example, original_cost: 2.22507e+06, min_cost: 2.22507e+06,
    // diff: -4.65661e-10, relative diff: 2.09279e-16.
    // Therefore, we use a threshold to filter out such false positives and
    // avoid an unlimited search.
if (original_cost * kFloatDeviationMinus > min_cost) { return min_cost - original_cost; }
}
return 0.0;
}
// Select and Merge two nodes
int32_t SbpGraph::PickAndMerge() {
if (node_list_.size() < kMinNodeInGraphForMerging) { return 0; }
// Pick the one with the smallest cut ratio
double min_cut_ratio = 1.0;
double curr_cut_ratio = 0.0;
SbpEdge* merging_edge = nullptr;
for (int32_t i = 0; i < node_list_.size(); i++) {
for (SbpEdge* edge_in : node_list_[i]->edges_in_) {
curr_cut_ratio = edge_in->FindCutRatio(threshold_);
if (curr_cut_ratio < min_cut_ratio) {
min_cut_ratio = curr_cut_ratio;
merging_edge = edge_in;
}
}
}
if (merging_edge != nullptr) {
// Merge two nodes on the edge with the minimum cut ratio
return NodeMerging(merging_edge->start_node_, merging_edge->end_node_);
} else {
    // Pick the couple with the most similar neighborhoods
std::vector<BinarySet> node_binary_sets(node_list_.size());
for (int32_t i = 0; i < node_list_.size(); i++) {
// Transfer edge to binary set
node_binary_sets[i].Initialize(node_list_.size());
node_binary_sets[i].AddEntry(i);
for (const SbpEdge* edge_in : node_list_[i]->edges_in_) {
node_binary_sets[i].AddEntry(edge_in->start_node_->node_list_id_);
}
for (const SbpEdge* edge_out : node_list_[i]->edges_out_) {
        node_binary_sets[i].AddEntry(edge_out->end_node_->node_list_id_);
}
}
// Find two nodes with largest common subset
// buffer of binary set
BinarySet buffer_binary_set(node_list_.size());
// Number of common edges
int32_t max_comm_edge_num = 0, curr_comm_edge_num = 0;
int32_t min_node_pair[2];
// Number of Sbp Signature in merged node
int32_t min_sbp_num = 0, curr_sbp_num = 0;
for (int32_t i = 0; i < node_list_.size(); i++) {
for (int32_t j = i + 1; j < node_list_.size(); j++) {
curr_sbp_num = node_list_[i]->cost_.size() * node_list_[j]->cost_.size();
if (curr_sbp_num <= threshold_) {
node_binary_sets[i].IntersectionTo(node_binary_sets[j], buffer_binary_set);
curr_comm_edge_num = buffer_binary_set.Total();
if (curr_comm_edge_num > max_comm_edge_num
|| (curr_comm_edge_num == max_comm_edge_num && curr_sbp_num < min_sbp_num)) {
min_node_pair[0] = i;
min_node_pair[1] = j;
max_comm_edge_num = curr_comm_edge_num;
min_sbp_num = curr_sbp_num;
}
}
}
}
if (max_comm_edge_num > 0) {
return NodeMerging(node_list_[min_node_pair[0]], node_list_[min_node_pair[1]]);
} else {
return 0;
}
}
}
// Clip an edge, remove it from graph
void SbpGraph::ClipEdge(SbpEdge* this_edge) const {
CheckAndRemoveFrom<SbpEdge*>(this_edge->end_node_->edges_in_, this_edge);
CheckAndRemoveFrom<SbpEdge*>(this_edge->start_node_->edges_out_, this_edge);
delete this_edge;
}
// Compute the minimum and maximum layer of each node in the graph
int32_t SbpGraph::ComputeLayer(
HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) const {
// Compute minimum layer
for (SbpNode* this_node : node_list_) {
this_node->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
}
// Find the largest minimum layer
int32_t max_min_layer = -1;
for (SbpNode* this_node : node_list_) {
if (max_min_layer < this_node->min_layer_) { max_min_layer = this_node->min_layer_; }
}
// Compute maximum layer
for (SbpNode* this_node : node_list_) {
this_node->SpreadMaxLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
}
for (SbpNode* this_node : node_list_) { this_node->LiftMaxLayer(max_min_layer); }
return max_min_layer;
}
// Find the trunk of the sbp graph, then reduce the wait time for tributaries
void SbpGraph::FindTrunk(int32_t max_min_layer,
HashMap<std::string, SbpNode*>& op_name2sbp_node) const {
// Summarize cost for each layer, on the trunk or tributaries
std::vector<double> trunk_cost(max_min_layer + 1, 0);
for (SbpNode* this_node : node_list_) {
trunk_cost[this_node->min_layer_] += this_node->GetMinCost();
}
// Decide trunks
double acc_cost = 0;
  // All the nodes with min_layer_ >= trunk_end_id are considered part of the trunk
int32_t trunk_end_id = max_min_layer;
for (int32_t layer_id = max_min_layer; layer_id >= 0; layer_id--) {
acc_cost += trunk_cost[layer_id];
if (acc_cost > 0.5 * wait_time_) {
trunk_end_id = layer_id;
break;
}
}
// Find out all the nodes on the trunk.
for (SbpNode* this_node : node_list_) {
if (this_node->min_layer_ >= trunk_end_id) { this_node->SpreadTrunk(op_name2sbp_node); }
}
// Compute maximum layer for tributaries
// Clear counter and initialize tributary layer for each sbp node
for (SbpNode* this_node : node_list_) {
this_node->counter_ = 0;
this_node->DropTributaryLayer(max_min_layer);
}
// Count the number of consumers and downstream nodes
for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }
// Compute maximum layer for tributaries
for (SbpNode* this_node : node_list_) { this_node->SpreadTributaryLayer(op_name2sbp_node); }
// Summarize cost for each layer on the trunk, store it to avoid subtraction of large values.
trunk_cost.assign(max_min_layer + 1, 0);
// tributary cost start from each min layer
std::vector<double> tributary_cost(max_min_layer + 1, 0);
  // Tributary cost becomes outdated after the max layer (i.e., before max layer + 1)
std::vector<double> outdated_tributary_cost(max_min_layer + 1, 0);
  // Operators on the trunk at each layer
std::vector<std::vector<SbpNode*>> trunk_ops(max_min_layer + 1);
for (SbpNode* this_node : node_list_) {
if (this_node->on_trunk_) {
trunk_cost[this_node->min_layer_] += this_node->GetMinCost();
trunk_ops[this_node->min_layer_].emplace_back(this_node);
} else {
double curr_min_cost = this_node->GetMinCost();
tributary_cost[this_node->min_layer_] += curr_min_cost;
outdated_tributary_cost[this_node->tributary_layer_] += curr_min_cost;
}
}
// Accumulate the cost from the consumer to the end, not including itself
std::vector<double> acc_trunk_cost(max_min_layer + 1, 0);
for (int32_t layer_id = max_min_layer; layer_id > 0; layer_id--) {
acc_trunk_cost[layer_id - 1] = acc_trunk_cost[layer_id] + trunk_cost[layer_id];
}
// Clear counter for each sbp node
for (SbpNode* this_node : node_list_) { this_node->counter_ = 0; }
// Count the number of consumers and downstream nodes
for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }
// Reduce the wait time for tributaries
for (SbpNode* this_node : node_list_) {
this_node->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time_);
}
// Reduce the wait time for trunk from the end to the begin
double acc_tributary_cost = outdated_tributary_cost[max_min_layer];
double used_tributary_cost = 0.0;
double curr_wait_time = 0.0;
for (int32_t layer_id = max_min_layer - 1; layer_id >= 0; layer_id--) {
    // Cannot move this backward since we also need to do it at the 0th layer.
    // At some point, the cost that has not been used yet would disappear.
if (tributary_cost[layer_id + 1] > used_tributary_cost) {
acc_tributary_cost -= tributary_cost[layer_id + 1] - used_tributary_cost;
used_tributary_cost = 0.0;
if (acc_tributary_cost < 0.0) {
        // Should not happen except for floating point error
std::cout << "Caution! Current accumulated tributary cost is: " << acc_tributary_cost
<< std::endl;
acc_tributary_cost = 0.0;
}
} else {
used_tributary_cost -= tributary_cost[layer_id + 1];
}
// accumulate tributary cost at this layer
acc_tributary_cost += outdated_tributary_cost[layer_id];
    // If we have more cost in tributaries, we reduce the wait time.
    // This code keeps ( acc_tributary_cost + used_tributary_cost ) invariant.
if (acc_tributary_cost > 0.0) {
if (acc_tributary_cost > wait_time_) {
curr_wait_time = 0.0;
acc_tributary_cost -= wait_time_;
used_tributary_cost += wait_time_;
} else {
curr_wait_time = wait_time_ - acc_tributary_cost;
used_tributary_cost += acc_tributary_cost;
acc_tributary_cost = 0.0;
}
// Reduce the wait time in the trunk
for (SbpNode* this_node : trunk_ops[layer_id]) {
this_node->SetTrunkWaitTime(curr_wait_time);
}
}
}
}
// Set wait time
void SbpGraph::SetWaitTime(double wait_time) { wait_time_ = wait_time; }
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
#include <algorithm>
#include <unordered_map>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace auto_parallel {
// A graph structure to deal with the SBP strategy.
// It contains a lot of eliminations to shrink the topology of the original graph.
// Furthermore, it contains some adjustment tricks for searching for a good strategy in the
// shrunk graph.
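// A rough usage sketch (illustrative only, based on the public interface declared below;
// the cost-initialization step is elided):
//   SbpGraph graph;
//   SbpNode* node = graph.GenerateNode();  // one node per operator
//   // ... connect nodes and initialize node/edge costs ...
//   graph.NodeAndEdgeEliminations();       // shrink the graph
//   graph.RandomSbpSignature(/*use_sbp_collector=*/false);  // initial strategy
//   graph.GreedyStrategy(4);               // local adjustment
//   graph.FinalizeSbp();                   // write the decision back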
class SbpGraph final {
public:
// Constructor
SbpGraph() = default;
  // Destructor
~SbpGraph();
OF_DISALLOW_COPY_AND_MOVE(SbpGraph);
bool operator==(const SbpGraph& other) { return this == &other; }
// Randomly assign a SbpSignature strategy
void RandomSbpSignature(bool use_sbp_collector) const;
// assign 0 to a SbpSignature strategy to avoid randomness
void SetDefaultSbpSig() const;
// Compute Cost for current strategy
double ComputeCost() const;
// Generate a node
SbpNode* GenerateNode();
// Merge all parallel edges & Check and eliminate all nodes with only one
// degree-in and one degree-out
int32_t NodeAndEdgeEliminations();
// Finalize Sbp Cost for the whole graph
void FinalizeSbp() const;
  // Use the greedy strategy to decide sbp for nodes in node_list_. Should be used
  // after we have an initial strategy.
  // Setting for_node to true will only use GreedyStrategy on nodes.
double GreedyStrategy(bool for_node) const;
// Use greedy strategy on the one ring neighborhood with the maximum number of points nbh_num.
double GreedyStrategy(int32_t nbh_num = 4) const;
// Find one strategy with finite cost for adjustment
Maybe<void> Find1Strategy4Greedy() const;
// Use brute force to search for a strategy with minimum cost for a neighborhood
double NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const;
// Set threshold_ for SbpNode Merging
void SetThreshold(int32_t threshold) { threshold_ = threshold; }
// Clip an edge, remove it from graph
  // Clipping an edge will also delete the nodes and edges contained in this edge. Though it does
  // not suffer from any compile-time or runtime bugs, clipping an edge on a shrunk graph is not
  // recommended. We should think about it carefully before any clipping.
void ClipEdge(SbpEdge* this_edge) const;
// Compute the minimum and maximum layer of each node in the graph
int32_t ComputeLayer(
HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) const;
// Find the trunk of the sbp graph, then reduce the wait time for tributaries
void FindTrunk(int32_t max_min_layer, HashMap<std::string, SbpNode*>& op_name2sbp_node) const;
// Set wait time
void SetWaitTime(double wait_time);
private:
friend class SbpCollector;
friend class SbpConstructor;
// All the nodes
std::vector<SbpNode*> node_list_;
  // Limitation: a merged node should not have a number of sbp signatures greater
  // than the threshold.
int32_t threshold_ = 100;
// Overlayable wait time for copy cost, which occurs before communication between devices.
double wait_time_ = 16500.0;
// Remove a node from the node list
void RemoveFromNodeList(SbpNode* this_node);
// Check and eliminate one node with only one degree-in and one degree-out
int32_t NodeElimination(SbpNode* this_node);
// Merge all parallel edges with given start_node_ and end_node_
int32_t EdgeElimination(SbpNode* this_node) const;
// Check and eliminate one child node
int32_t ChildElimination(SbpNode* this_node);
// Merge two nodes
int32_t NodeMerging(SbpNode* first, SbpNode* second);
// Select two nodes and merge them
int32_t PickAndMerge();
void DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,
std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
std::vector<int32_t>& order2nbh_id, std::vector<int32_t>& nbh_id2order,
std::vector<double>& order2acc_min_in_nbh_cost,
std::vector<std::vector<double>>& out_nbh_costs,
std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,
std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,
double curr_cost) const;
bool DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,
std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const;
};
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_edge.h"
#include "oneflow/core/auto_parallel/sbp_graph.h"
namespace oneflow {
namespace auto_parallel {
// Function definitions live in this cpp file. They should be put in one file due to the use of
// templates; otherwise we would need to declare the specific templates at the end of the cpp file.
SbpNode::SbpNode(SbpNode* first, SbpNode* second) {
half_node_.resize(2);
half_node_[0] = first;
half_node_[1] = second;
  // Get the edge between first and second
  // NOTE: There must be zero or one edge between them
SbpEdge* common_edge = nullptr;
for (int32_t k = 0; k < first->edges_in_.size(); k++) {
if (first->edges_in_[k]->start_node_ == second) {
common_edge = first->edges_in_[k];
}
}
for (int32_t k = 0; k < first->edges_out_.size(); k++) {
if (first->edges_out_[k]->end_node_ == second) { common_edge = first->edges_out_[k]; }
}
  // Find all available merged SbpSignatures (those whose edge cost is less than the threshold).
if (common_edge) {
double min_cost = GetMaxVal<float>();
for (const auto& row : common_edge->cost_) {
for (const double& c : row) min_cost = std::min(min_cost, c);
}
    // If there is no case we can choose, we will blow up
for (int32_t i = 0; i < first->cost_.size(); i++) {
for (int32_t j = 0; j < second->cost_.size(); j++) {
const double edge_cost =
common_edge->start_node_ == first ? common_edge->cost_[i][j] : common_edge->cost_[j][i];
if (edge_cost < GetValidMaxCopyCost()) {
merged_sig_id2children_sig_id_.emplace_back(std::make_pair(i, j));
cost_.emplace_back(edge_cost + first->cost_[i] + second->cost_[j]);
}
}
}
CHECK(merged_sig_id2children_sig_id_.size() > 0)
<< "0 size for merge child edge, min cost: " << min_cost;
} else {
for (int32_t i = 0; i < first->cost_.size(); i++) {
for (int32_t j = 0; j < second->cost_.size(); j++) {
merged_sig_id2children_sig_id_.emplace_back(std::make_pair(i, j));
cost_.emplace_back(first->cost_[i] + second->cost_[j]);
}
}
}
  // Initialize the default sbp choice.
  // If the original sbp pair does not survive the merge, then use 0 as the default.
final_sbp_sig_id_ = 0;
// Track the original strategy
for (int32_t sig_id = 0; sig_id < merged_sig_id2children_sig_id_.size(); sig_id++) {
if (merged_sig_id2children_sig_id_[sig_id].first == first->final_sbp_sig_id_
&& merged_sig_id2children_sig_id_[sig_id].second == second->final_sbp_sig_id_) {
final_sbp_sig_id_ = sig_id;
}
}
// Merge edges_in_
edges_in_.reserve(first->edges_in_.size() + second->edges_in_.size());
edges_in_.insert(edges_in_.end(), first->edges_in_.begin(), first->edges_in_.end());
edges_in_.insert(edges_in_.end(), second->edges_in_.begin(), second->edges_in_.end());
// Merge edges_out_
edges_out_.reserve(first->edges_out_.size() + second->edges_out_.size());
edges_out_.insert(edges_out_.end(), first->edges_out_.begin(), first->edges_out_.end());
edges_out_.insert(edges_out_.end(), second->edges_out_.begin(), second->edges_out_.end());
// Merge SbpEdge Cost
for (SbpEdge*& this_edge : first->edges_in_) {
this_edge->DuplicateCost(false, true, merged_sig_id2children_sig_id_);
this_edge->end_node_ = this;
}
for (SbpEdge*& this_edge : first->edges_out_) {
this_edge->DuplicateCost(true, true, merged_sig_id2children_sig_id_);
this_edge->start_node_ = this;
}
for (SbpEdge*& this_edge : second->edges_in_) {
this_edge->DuplicateCost(false, false, merged_sig_id2children_sig_id_);
this_edge->end_node_ = this;
}
for (SbpEdge*& this_edge : second->edges_out_) {
this_edge->DuplicateCost(true, false, merged_sig_id2children_sig_id_);
this_edge->start_node_ = this;
}
// Remove edges from original nodes
first->edges_in_.clear();
first->edges_out_.clear();
second->edges_in_.clear();
second->edges_out_.clear();
// Move edges between two nodes to each half node
for (int32_t k = edges_out_.size() - 1; k >= 0; k--) {
if (edges_out_[k]->end_node_ == this) {
// Remove this edge from edges_out_ and edges_in_ and put it inside the node
CheckAndRemoveFrom<SbpEdge*>(edges_in_, edges_out_[k]);
first->edges_out_.emplace_back(edges_out_[k]);
second->edges_in_.emplace_back(edges_out_[k]);
RemoveFrom<SbpEdge*>(edges_out_, k);
}
}
}
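// Illustrative example: merging a node with 2 signatures and a node with 3 signatures that
// share no edge yields 6 merged signatures, with merged_sig_id2children_sig_id_ =
// {(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)} and cost_[k] equal to the sum of the two
// children's costs.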
SbpNode::~SbpNode() {
for (auto& edge_out : edges_out_) { delete edge_out; }
for (auto& child_node : children_) {
if (child_node->edges_in_.size()) { delete child_node->edges_in_[0]; }
delete child_node;
}
for (auto& half_node : half_node_) { delete half_node; }
}
void SbpNode::InitializeSbp() {
global_sbp_sig_size_ = sbp_sig_list_.size();
cost_.resize(sbp_sig_list_.size());
}
// Let one node point to another
void SbpNode::StartPointToEnd(SbpNode* start_node, SbpNode* end_node) {
// generate the edge between them
SbpEdge* e = new SbpEdge(start_node, end_node);
start_node->edges_out_.emplace_back(e);
end_node->edges_in_.emplace_back(e);
}
void SbpNode::PointFrom(SbpNode* start_node) { StartPointToEnd(start_node, this); }
void SbpNode::PointTo(SbpNode* end_node) { StartPointToEnd(this, end_node); }
void SbpNode::SummarizeCost() {
if (children_.size() == child_node_sbp_sig_.size()) { return; }
int32_t previous_children_size = child_node_sbp_sig_.size();
child_node_sbp_sig_.resize(children_.size());
// Only deal with new children_
for (int32_t child = previous_children_size; child < children_.size(); child++) {
child_node_sbp_sig_[child].resize(cost_.size());
for (int32_t sbp_this = 0; sbp_this < cost_.size(); sbp_this++) {
double min_cost = 0, curr_cost = 0;
for (int32_t sbp_child = 0; sbp_child < children_[child]->cost_.size(); sbp_child++) {
if (children_[child]->edges_in_.size()) {
// edge in graph: father -> child
curr_cost = children_[child]->edges_in_[0]->cost_[sbp_this][sbp_child]
+ children_[child]->cost_[sbp_child];
} else {
// edge in graph: child -> father
curr_cost = children_[child]->edges_out_[0]->cost_[sbp_child][sbp_this]
+ children_[child]->cost_[sbp_child];
}
// update min_cost with fixed SbpSignature for this node and child node
if (sbp_child == 0 || curr_cost < min_cost) {
min_cost = curr_cost;
child_node_sbp_sig_[child][sbp_this] = sbp_child;
}
}
// Add the cost for child node to this node
cost_[sbp_this] += min_cost;
}
}
}
bool SbpNode::EliminateItselfAsChild() {
if (edges_in_.size() + edges_out_.size() == 1) {
if (edges_in_.size()) {
// edge in graph: father -> this_node
SbpNode* father = edges_in_[0]->start_node_;
father->children_.emplace_back(this);
CheckAndRemoveFrom<SbpEdge*>(father->edges_out_, edges_in_[0]);
father->SummarizeCost();
} else {
// edge in graph: this_node -> father
SbpNode* father = edges_out_[0]->end_node_;
father->children_.emplace_back(this);
CheckAndRemoveFrom<SbpEdge*>(father->edges_in_, edges_out_[0]);
father->SummarizeCost();
}
// successfully eliminate this node
return true;
}
// can not eliminate this node
return false;
}
void SbpNode::FinalizeSbp() {
if (!half_node_.empty()) {
// Finalize Sbp of merged nodes
half_node_[0]->final_sbp_sig_id_ = merged_sig_id2children_sig_id_[final_sbp_sig_id_].first;
half_node_[1]->final_sbp_sig_id_ = merged_sig_id2children_sig_id_[final_sbp_sig_id_].second;
}
// Finalize Sbp of children_
for (int32_t i = 0; i < children_.size(); i++) {
children_[i]->final_sbp_sig_id_ = child_node_sbp_sig_[i][this->final_sbp_sig_id_];
}
// Finalize Sbp of half_node_ Attachment
if (!half_node_.empty()) {
half_node_[0]->FinalizeSbp();
half_node_[1]->FinalizeSbp();
}
// Finalize Sbp of edges in edges_out_
for (const auto& edge_out : edges_out_) { edge_out->FinalizeSbp(); }
  // Finalize sbp again in case the node on the other side is not finalized yet.
  // This may happen when the two sides of an edge are merged into two larger
  // nodes and this edge is just a sub-edge.
for (const auto& edge_in : edges_in_) { edge_in->FinalizeSbp(); }
// Finalize Sbp of children_ Attachment
for (int32_t i = 0; i < children_.size(); i++) {
children_[i]->FinalizeSbp();
for (const auto& edge_in : children_[i]->edges_in_) { edge_in->FinalizeSbp(); }
}
}
double SbpNode::GreedyStrategy() {
// Current Cost, Minimum Cost, Cost with original sbp
double curr_cost = 0;
double original_cost = EvalNbhCost();
double min_cost = original_cost;
int32_t min_sbp = final_sbp_sig_id_;
for (int32_t sbp = 0; sbp < cost_.size(); sbp++) {
final_sbp_sig_id_ = sbp;
curr_cost = EvalNbhCost();
if (curr_cost < min_cost) {
min_cost = curr_cost;
min_sbp = sbp;
}
}
final_sbp_sig_id_ = min_sbp;
return min_cost - original_cost;
}
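// GreedyStrategy() is a single-node coordinate-descent step: it evaluates the
// 1-ring cost of every candidate signature while all neighbors stay fixed,
// keeps the best one, and returns (new cost - original cost), which is <= 0.
// A caller would presumably sweep this over all nodes until no step improves.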
double SbpNode::EvalNbhCost() const {
// Sum the cost of this node and all its incident edges under the current assignment
double curr_cost = cost_[final_sbp_sig_id_];
for (SbpEdge* this_edge : edges_in_) {
curr_cost += this_edge->cost_[this_edge->start_node_->final_sbp_sig_id_][final_sbp_sig_id_];
}
for (SbpEdge* this_edge : edges_out_) {
curr_cost += this_edge->cost_[final_sbp_sig_id_][this_edge->end_node_->final_sbp_sig_id_];
}
return curr_cost;
}
double SbpNode::EvalOutNbhCost(
const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const {
// check if this node is in the node list
CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl;
// Cost with original sbp
double curr_cost = cost_[final_sbp_sig_id_];
for (SbpEdge* this_edge : edges_in_) {
// if the start node is not in the neighborhood
if (node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_)
== node_list_id2nbh_id.end()) {
curr_cost += this_edge->cost_[this_edge->start_node_->final_sbp_sig_id_][final_sbp_sig_id_];
}
}
for (SbpEdge* this_edge : edges_out_) {
// if the end node is not in the neighborhood
if (node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_)
== node_list_id2nbh_id.end()) {
curr_cost += this_edge->cost_[final_sbp_sig_id_][this_edge->end_node_->final_sbp_sig_id_];
}
}
return curr_cost;
}
// Compute the cost between this node and adjacent nodes with a lower order
double SbpNode::EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
const std::vector<int32_t>& nbh_id2order) const {
// check if this node is in the node list
CHECK(node_list_id_ >= 0) << "Compute in cost for a node out of the node list";
// check if the node is in the neighborhood
const auto& this_it = node_list_id2nbh_id.find(node_list_id_);
CHECK(this_it != node_list_id2nbh_id.end())
<< "Compute in cost for a node out of the neighborhood";
// Accumulate the cost between this node and adjacent nodes with a lower order
int32_t order = nbh_id2order[this_it->second];
double curr_cost = 0;
for (SbpEdge* this_edge : edges_in_) {
const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);
// if the start node is in the neighborhood
if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {
curr_cost += this_edge->cost_[this_edge->start_node_->final_sbp_sig_id_][final_sbp_sig_id_];
// If the accumulated cost exceeds the cut-off, end this function and return infinity.
if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }
}
}
for (SbpEdge* this_edge : edges_out_) {
const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);
// if the end node is in the neighborhood
if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {
curr_cost += this_edge->cost_[final_sbp_sig_id_][this_edge->end_node_->final_sbp_sig_id_];
if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }
}
}
return curr_cost;
}
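// By counting only edges whose other endpoint has a lower order, summing
// EvalInNbhCost() over every node of the neighborhood counts each internal
// edge exactly once. Returning GetMaxVal<float>() early acts as a cut-off:
// once the partial sum exceeds GetValidMaxCopyCost() the candidate assignment
// is treated as infeasible.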
double SbpNode::EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
const std::vector<int32_t>& nbh_id2order) const {
// check if this node is in the node list
CHECK(node_list_id_ >= 0) << "Compute out cost for a node out of the node list" << std::endl;
// check if the node is in the neighborhood
const auto& this_it = node_list_id2nbh_id.find(node_list_id_);
CHECK(this_it != node_list_id2nbh_id.end())
<< "Compute out cost for a node out of the neighborhood" << std::endl;
// Compute the minimum cost between this node and adjacent nodes with a higher order
int32_t order = nbh_id2order[this_it->second];
double curr_cost = 0;
for (SbpEdge* this_edge : edges_in_) {
const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);
// if the start node is in the neighborhood
if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {
curr_cost += this_edge->GetMinCost();
}
}
for (SbpEdge* this_edge : edges_out_) {
const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);
// if the end node is in the neighborhood
if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {
curr_cost += this_edge->GetMinCost();
}
}
return curr_cost;
}
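// EvalMinInNbhCost() is the mirror image of EvalInNbhCost(): for edges toward
// higher-order neighbors, whose signatures are still undecided during
// enumeration, the minimum possible edge cost is used, giving a lower bound
// that can prune assignments which cannot beat the current best.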
void SbpNode::OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const {
nbh_1ring.resize(edges_in_.size() + edges_out_.size() + 1);
int32_t nbh_id = 0;
nbh_1ring[nbh_id] = node_list_id_;
for (SbpEdge* this_edge : edges_in_) {
nbh_id++;
nbh_1ring[nbh_id] = this_edge->start_node_->node_list_id_;
}
for (SbpEdge* this_edge : edges_out_) {
nbh_id++;
nbh_1ring[nbh_id] = this_edge->end_node_->node_list_id_;
}
}
// Get the n-ring neighborhood of this node.
// Buffers are pre-allocated by the caller, which is faster.
void SbpNode::NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,
std::vector<int32_t>& nbh_1ring,
const std::vector<SbpNode*>& node_list,
std::vector<bool>& node_tags) const {
// Initialize 0 ring
if (n <= 0) { n = 0; }
nbh_n_ring.resize(1);
nbh_n_ring[0] = node_list_id_;
node_tags[node_list_id_] = true;
int32_t l = 0;
// do ring expansion for n times
for (int32_t i = 0; i < n; i++) {
for (int32_t r = nbh_n_ring.size(); l < r; l++) {
node_list[nbh_n_ring[l]]->OneRingNeighborhood(nbh_1ring);
for (auto nbh_id : nbh_1ring) {
if (!node_tags[nbh_id]) {
nbh_n_ring.push_back(nbh_id);
node_tags[nbh_id] = true;
}
}
}
}
// Reset the tags to false so the buffer can be reused
for (auto nbh_id : nbh_n_ring) { node_tags[nbh_id] = false; }
}
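// NRingNeighborhood() is a breadth-first expansion in which nbh_n_ring doubles
// as the BFS queue: l is the read cursor, r snapshots the frontier end for
// each ring, and node_tags is a pre-allocated visited bitmap. The tags are
// reset at the end so the next query can reuse the buffer without a full clear.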
// Get or compute the minimum layer of this node
int32_t SbpNode::GetMinLayer(
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {
if (min_layer_ >= 0) { return min_layer_; }
if (!op_node_) { return min_layer_; }
for (SbpEdge* this_edge : edges_in_) {
int32_t producer_min_layer =
this_edge->start_node_->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
}
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) {
int32_t producer_min_layer =
it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
}
}
if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {
for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) {
int32_t producer_min_layer =
it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);
if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }
}
}
}
return ++min_layer_;
}
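// GetMinLayer(): min_layer_ is the longest-path depth in the DAG formed by
// data edges plus control edges (ctrl_in_op_name and the extra mutable-op
// control deps), computed by memoized recursion: sources end at layer 0 and
// every other node at 1 + max over its producers. The min_layer_ >= 0 test
// doubles as the memoization guard, and proxy nodes without an op_node_ stay
// at -1.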
// Spread the minimum layer to compute the maximum layer of producers
void SbpNode::SpreadMaxLayer(
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {
if (min_layer_ <= 0) { return; }
int32_t producer_max_lay = min_layer_ - 1;
for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->DropMaxLayer(producer_max_lay); }
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }
}
if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {
for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }
}
}
}
// Drop down the maximum layer with the minimum layer from consumer
void SbpNode::DropMaxLayer(int32_t upper_bound) {
if (upper_bound < max_layer_ || max_layer_ < 0) { max_layer_ = upper_bound; }
}
// Set max_layer_ = min_layer_ if this node does not have any consumer
// This is the end of the whole graph
// We could also set it to be the maximum of the min_layer_ in the graph. (It should be the same.)
void SbpNode::LiftMaxLayer() {
if (max_layer_ < min_layer_) { max_layer_ = min_layer_; }
}
// Set max_layer_ = upper_bound if this node does not have any consumer
void SbpNode::LiftMaxLayer(int32_t upper_bound) {
if (max_layer_ < min_layer_) { max_layer_ = upper_bound; }
}
// Get the minimum element in Cost
double SbpNode::GetMinCost() const {
// Check the size of Cost
CHECK(!cost_.empty()) << "Cost not initialized!" << std::endl;
// Return the minimum cost entry
return *std::min_element(cost_.begin(), cost_.end());
}
// Get the cut ratio
double SbpNode::GetCutRatio() const {
double curr_cut_ratio = 1.0;
for (auto* this_edge : edges_in_) { curr_cut_ratio *= this_edge->GetCutRatio(); }
for (auto* this_edge : edges_out_) { curr_cut_ratio *= this_edge->GetCutRatio(); }
return curr_cut_ratio;
}
// Judge if this node is on the trunk
// If so, judge it for its producer/upstream nodes
void SbpNode::SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
// Skip it if this node is already judged.
if (on_trunk_) { return; }
// Skip sbp proxies (min_layer_ < 0). This pass runs before proxies are introduced.
if (min_layer_ < 0) { return; }
on_trunk_ = true;
// If this node is on the trunk, then every producer with min_layer_ >=
// (this node's min_layer_ - 1) is also considered to be on the trunk
for (SbpEdge* this_edge : edges_in_) {
if (this_edge->start_node_->min_layer_ >= min_layer_ - 1) {
this_edge->start_node_->SpreadTrunk(op_name2sbp_node);
}
}
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end() && it->second->min_layer_ >= min_layer_ - 1) {
it->second->SpreadTrunk(op_name2sbp_node);
}
}
}
// Count consumers and any downstream nodes defined by control edges
void SbpNode::RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
// counter_ should be cleared before running this.
// skip the proxy nodes and the sources
if (min_layer_ <= 0) { return; }
for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->counter_++; }
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) { it->second->counter_++; }
}
}
// Compute the minimal available wait time for producers or upstream nodes
void SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost,
const std::vector<double>& acc_trunk_cost,
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
double wait_time) {
// skip the proxy nodes and the sources
if (min_layer_ <= 0) { return; }
// Either some consumers or downstream nodes have not finished spreading, or this node was already visited.
if (counter_) { return; }
if (on_trunk_) {
// Nodes on the trunk do not have any accumulated cost
acc_trunk_cost_ = 0;
} else {
if (acc_trunk_cost_ < 0) {
// This node does not have any consumer or downstream node
acc_trunk_cost_ = acc_trunk_cost[min_layer_ - 1];
} else {
// Add the trunk cost at this layer
acc_trunk_cost_ += trunk_cost[min_layer_];
}
}
// Reduce the wait time for edges_in_, put the rest of the trunk cost in the producers
for (SbpEdge* this_edge : edges_in_) {
CHECK(this_edge->wait_time_ < 0)
<< "wait_time_ of this edge is assigned twice!" << std::endl;
SbpNode* producer = this_edge->start_node_;
// Accumulate the cost from the start node to this node
double curr_trunk_cost =
acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];
if (curr_trunk_cost >= wait_time) {
// The remaining trunk cost can cover all the wait time
this_edge->wait_time_ = 0.0;
curr_trunk_cost -= wait_time;
} else {
// The remaining trunk cost can only cover part of the wait time
this_edge->wait_time_ = wait_time - curr_trunk_cost;
curr_trunk_cost = 0.0;
}
// This discourages non-matching edges
// For example:
// (1) P->S0->S0->S0->B
// (2) P->B->B->B->B
// We would use (2) when the tensor is relatively tiny.
// Do not inherit trunk cost for nodes on the trunk
if (!producer->on_trunk_) {
// Inherit the minimal of the trunk cost from consumers
producer->DropAvailWaitTime(curr_trunk_cost);
}
producer->counter_--;
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time);
}
// Put the rest of the trunk cost in the upstream nodes.
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) {
SbpNode* producer = it->second;
// Do not inherit trunk cost for nodes on the trunk
if (!producer->on_trunk_) {
// Accumulate the cost from the start node to this node
double curr_trunk_cost =
acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];
// Inherit the minimal of the trunk cost from consumers
producer->DropAvailWaitTime(curr_trunk_cost);
}
producer->counter_--;
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time);
}
}
// Set counter_ to -1 so this node is not visited again.
counter_--;
}
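// SpreadAvailWaitTime() walks the graph in reverse topological order;
// counter_, prepared by RaiseConsumerNum(), gates the recursion until all
// consumers are done. The idea is that transfer wait time can be hidden
// behind trunk computation: acc_trunk_cost[] holds per-layer trunk cost
// accumulated from the end, and whatever trunk cost lies between the
// producer's layer and this node's layer (plus the cost already accumulated
// downstream) offsets the edge's wait_time_; only the uncovered remainder is
// charged to the edge.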
// Drop down the available wait time with the minimum cost from downstream
void SbpNode::DropAvailWaitTime(double curr_trunk_cost) {
if (acc_trunk_cost_ < 0.0 || acc_trunk_cost_ > curr_trunk_cost) {
acc_trunk_cost_ = curr_trunk_cost;
}
}
// Assemble copy cost for all the incoming edges
void SbpNode::InitializeCopyCost(bool use_sbp_collector) {
for (SbpEdge* this_edge : edges_in_) {
const auto* sbp_node_producer = this_edge->start_node_;
OpNode* producer = sbp_node_producer->op_node_;
// skip it if proxy
if (use_sbp_collector && !producer) { continue; }
// look through input blobs
for (const std::string& ibn : op_node_->op().input_bns()) {
if (producer->op().op_name() == op_node_->SrcNode4Ibn(ibn).op().op_name()) {
this_edge->InitializeCopyCost(ibn, use_sbp_collector);
}
}
// Add wait time
for (auto& cost_row : this_edge->cost_) {
for (auto& cost_value : cost_row) {
// If transferring between devices, we need to add wait time.
if (cost_value > 0.0) { cost_value += this_edge->wait_time_; }
}
}
}
}
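// InitializeCopyCost() fills in the copy (boxing) cost per input blob arriving
// over each incoming edge, and wait_time_ is then added only to non-zero
// entries: a zero entry means producer and consumer SBP already match, so no
// transfer happens and there is nothing to wait for.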
// Reduce and set the wait time for ops in the trunk
void SbpNode::SetTrunkWaitTime(double trunk_wait_time) {
// only reduce the wait time for operators in the trunk
if (on_trunk_) {
// Reduce the wait time for edges_out_
for (SbpEdge* edge_out : edges_out_) {
if (edge_out->wait_time_ < 0.0 || edge_out->wait_time_ > trunk_wait_time) {
edge_out->wait_time_ = trunk_wait_time;
}
}
// Might reduce it for edges_in_
}
}
// Drop down the tributary layer with the minimum layer from consumer
void SbpNode::DropTributaryLayer(int32_t upper_bound) {
if (upper_bound < tributary_layer_ || tributary_layer_ < 0) { tributary_layer_ = upper_bound; }
}
// Compute maximum layer for tributaries
void SbpNode::SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {
if (counter_ || min_layer_ <= 0) { return; }
int32_t producer_max_lay = 0;
if (on_trunk_) {
producer_max_lay = min_layer_ - 1;
} else {
// On a tributary, the operator could be run later.
producer_max_lay = tributary_layer_;
// producer_max_lay = tributary_layer_ - 1;
}
for (SbpEdge* this_edge : edges_in_) {
this_edge->start_node_->DropTributaryLayer(producer_max_lay);
if (--this_edge->start_node_->counter_ == 0) {
this_edge->start_node_->SpreadTributaryLayer(op_name2sbp_node);
}
}
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {
const auto& it = op_name2sbp_node.find(ctrl_in_op_name);
if (it != op_name2sbp_node.end()) {
it->second->DropTributaryLayer(producer_max_lay);
if (--it->second->counter_ == 0) { it->second->SpreadTributaryLayer(op_name2sbp_node); }
}
}
counter_--;
}
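// SpreadTributaryLayer(): tributary (off-trunk) ops can be delayed without
// slowing the trunk down, so their producers' upper layer bound is the
// consumer's tributary_layer_ itself rather than min_layer_ - 1. counter_
// again acts as a reverse topological gate: a producer recurses only after
// every one of its consumers has dropped a bound onto it.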
SbpEdge* SbpNode::FindEdgeWithNode(const SbpNode* other_node) const {
for (auto* sbp_edge : edges_in_) {
if (sbp_edge->start_node_ == other_node) { return sbp_edge; }
}
for (auto* sbp_edge : edges_out_) {
if (sbp_edge->end_node_ == other_node) { return sbp_edge; }
}
return nullptr;
}
// Return the SbpSignature that this node decided to use
const NdSbpSignature& SbpNode::FinalSbpSignature() const {
CHECK(!sbp_sig_list_.empty()) << "Asking for sbp signature for an empty node";
return sbp_sig_list_[final_sbp_sig_id_];
}
} // namespace auto_parallel
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>
#include "oneflow/core/auto_parallel/binary_set.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
namespace oneflow {
namespace auto_parallel {
class SbpEdge;
// A node structure to deal with the SBP strategy.
// Please see SbpGraph for the whole algorithm and introduction.
class SbpNode final {
public:
// default constructor
SbpNode() : final_sbp_sig_id_(0) {}
// This constructor merges two nodes into one
SbpNode(SbpNode* first, SbpNode* second);
~SbpNode();
OF_DISALLOW_COPY_AND_MOVE(SbpNode);
bool operator==(const SbpNode& other) const { return this == &other; }
// Another node points to this node
void PointFrom(SbpNode* start_node);
// This node points to another node
void PointTo(SbpNode* end_node);
SbpEdge* FindEdgeWithNode(const SbpNode* other_node) const;
// Check whether this node can be eliminated as a child of its unique neighbor, and do so.
// Only used by SbpGraph since it needs to remove this node from the node list afterwards.
bool EliminateItselfAsChild();
// Initialize SbpSignature from Signature Objects
void InitializeSbp();
// Return the SbpSignature that this node decided to use
const NdSbpSignature& FinalSbpSignature() const;
// Recompute the computation cost after adding child nodes to it
void SummarizeCost();
// Determine Final SbpSignature for attachment of this node
void FinalizeSbp();
// Use a greedy strategy to pick the sbp signature with minimum cost for this node.
// You should have an initial strategy before running this.
double GreedyStrategy();
// Evaluate the summary of cost between the neighborhood and outside nodes
double EvalOutNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const;
// Evaluate the summary of cost within the neighborhood
// We only accumulate the edge cost with a lower order.
double EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
const std::vector<int32_t>& nbh_id2order) const;
// Evaluate the summary of cost within the neighborhood
// We only accumulate the minimum edge cost with a higher order.
double EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,
const std::vector<int32_t>& nbh_id2order) const;
// Get the one ring neighborhood of this node, which is itself and all the adjacent nodes.
void OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const;
// Get the n-ring neighborhood of this node.
// Buffers are pre-allocated by the caller, which is faster.
void NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,
std::vector<int32_t>& nbh_1ring, const std::vector<SbpNode*>& node_list,
std::vector<bool>& node_tags) const;
// Get or compute the minimum layer of this node
int32_t GetMinLayer(
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);
// Spread the minimum layer to compute the maximum layer of producers
void SpreadMaxLayer(
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);
// Set max_layer_ = min_layer_ if this node does not have any consumer
void LiftMaxLayer();
// Set max_layer_ = upper_bound if this node does not have any consumer
void LiftMaxLayer(int32_t upper_bound);
// Compute maximum layer for tributaries
void SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Drop down the tributary layer
void DropTributaryLayer(int32_t upper_bound);
// Get the minimum element in Cost
double GetMinCost() const;
// get the cut ratio
double GetCutRatio() const;
// Judge if this node is on the trunk
// If so, judge it for its producer/upstream nodes
void SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Count consumers and any downstream nodes defined by control edges
// for producers or upstream nodes
void RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node);
// Compute the minimal available wait time for producers or upstream nodes
void SpreadAvailWaitTime(const std::vector<double>& trunk_cost,
const std::vector<double>& acc_trunk_cost,
const HashMap<std::string, SbpNode*>& op_name2sbp_node,
double wait_time);
// Reduce and set the wait time for ops in the trunk
void SetTrunkWaitTime(double trunk_wait_time);
// Assemble copy cost for all the incoming edges
void InitializeCopyCost(bool use_sbp_collector);
private:
friend class SbpEdge;
friend class SbpGraph;
friend class SbpCollector;
friend class SbpConstructor;
// compound edge in
std::vector<SbpEdge*> edges_in_;
// compound edge out
std::vector<SbpEdge*> edges_out_;
// Location in node_list of SbpGraph
int32_t node_list_id_ = -1;
// Global SbpSignature List Size
int32_t global_sbp_sig_size_ = -1;
// Decide to use SbpSignature with this id
int32_t final_sbp_sig_id_;
// Available SbpSignature object for this node
std::vector<NdSbpSignature> sbp_sig_list_;
// Cost[sbp] is Computation Cost when using sbp_sig_list_[sbp]
std::vector<double> cost_;
// Child node list
std::vector<SbpNode*> children_;
// SbpSignature for each child node when using a specific SbpSignature for this node.
// Its dimension is (number of child nodes) x (number of available SbpSignatures
// for this node).
std::vector<std::vector<int32_t>> child_node_sbp_sig_;
// Merge two nodes into this compound node
std::vector<SbpNode*> half_node_;
// We should delete those merged signatures which have a very large cost, to speed things up.
// Maps each new sbp_sig_list_ index to the sig_index pair of the two half_node_s.
std::vector<std::pair<int32_t, int32_t>> merged_sig_id2children_sig_id_;
std::vector<BinarySet> parallel_candidates_;
OpNode* op_node_ = nullptr;
// We divide the sbp graph into multiple layers.
// min_layer_ is the minimum layer number to run this op as soon as possible.
// max_layer_ is the maximum layer number without slowing down the whole process of the graph.
// producer.max_layer_ < this_node.min_layer_ <= this_node.max_layer_ < consumer.min_layer_
int32_t min_layer_ = -1, max_layer_ = -1;
// Maximum layer in tributaries
int32_t tributary_layer_ = -1;
// Whether we are on the trunk
bool on_trunk_ = false;
// A counter buffer used for topological traversal and similar bookkeeping
int32_t counter_ = 0;
// Accumulate trunk cost from consumer to the end
double acc_trunk_cost_ = -1.0;
// Let one node point to another
void StartPointToEnd(SbpNode* start_node, SbpNode* end_node);
// Evaluate the summary of cost in the 1-ring neighborhood.
double EvalNbhCost() const;
// Drop down the maximum layer with the minimum layer from consumer
void DropMaxLayer(int32_t upper_bound);
// Drop down the available wait time with the minimum cost from downstream
void DropAvailWaitTime(double curr_trunk_cost);
}; // class SbpNode
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
namespace oneflow {
namespace auto_parallel {
// Judge whether we need the same SBP for both producer and consumer
bool RequireSameSbp(const OpNode* consumer, const std::string& ibn) {
// is mutable
const auto& input_blob_modifier = consumer->op().InputBlobModifier4Ibn(ibn);
if (input_blob_modifier.has_is_mutable() && input_blob_modifier.is_mutable()) { return true; }
// kOFRecord or kTensorBuffer don't accept boxing
const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn);
const OpNode& producer = consumer->ProducerOpNode4Lbi(lbi);
const BlobDesc& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi);
return (logical_blob_desc.data_type() == DataType::kOFRecord
|| logical_blob_desc.data_type() == DataType::kTensorBuffer);
}
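// RequireSameSbp(): a mutable input is written in place in the producer's
// memory, so the consumer must see exactly the producer's SBP; kOFRecord and
// kTensorBuffer blobs do not accept boxing, so producer and consumer are
// likewise forced to share the same SBP.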
} // namespace auto_parallel
} // namespace oneflow