Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"
namespace py = pybind11;
namespace oneflow {
namespace {
Maybe<void> CheckCurrentRankGroupConsistency() {
const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
const auto& ctx = JUST(CheckTransportToken(rank_group));
JUST(ctx->WaitDone());
return Maybe<void>::Ok();
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("check_current_rank_group_consistency", &CheckCurrentRankGroupConsistency);
}
} // namespace oneflow
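A minimal usage sketch for the binding above. It assumes the registered function ends up on oneflow._oneflow_internal (the registry's target module is not shown in this file) and that it is invoked on every rank of the current rank group:

import oneflow as flow

# Assumed call path; check_current_rank_group_consistency blocks until the
# rank-group transport handshake (CheckTransportToken + WaitDone) finishes.
flow._oneflow_internal.check_current_rank_group_consistency()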
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/session.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/api/python/session/session.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("IsSessionInited", &IsSessionInited);
m.def("InitLazyGlobalSession", &InitLazyGlobalSession);
m.def("InitEagerGlobalSession", &InitEagerGlobalSession);
m.def("DestroyLazyGlobalSession", &DestroyLazyGlobalSession);
m.def("StartLazyGlobalSession", &StartLazyGlobalSession);
m.def("StopLazyGlobalSession", &StopLazyGlobalSession);
using namespace oneflow;
py::class_<MultiClientSessionContext, std::shared_ptr<MultiClientSessionContext>>(
m, "SessionContext")
.def(py::init<const std::shared_ptr<EnvGlobalObjectsScope>&>())
.def("try_init",
[](MultiClientSessionContext& session, const std::string& config_proto_str) {
return session.TryInit(config_proto_str).GetOrThrow();
})
.def("update_resource",
[](MultiClientSessionContext& session, const std::string& reso_proto_str) {
return session.UpdateResource(reso_proto_str).GetOrThrow();
});
m.def("NewSessionId", &NewSessionId);
py::class_<LogicalConfigProtoContext>(m, "LogicalConfigProtoContext")
.def(py::init<const std::string&>());
}
} // namespace oneflow
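A hedged sketch of how the plain function bindings above can be reached from Python, assuming they land on oneflow._oneflow_internal; the SessionContext class itself is normally driven by OneFlow's own session bootstrap code rather than by user code:

import oneflow as flow

internal = flow._oneflow_internal   # assumed module path for the bindings above
print(internal.IsSessionInited())   # True once a SessionGlobalObjectsScope exists
session_id = internal.NewSessionId()  # allocate a fresh session id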
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_SESSION_SESSION_H_
#define ONEFLOW_API_PYTHON_SESSION_SESSION_H_
#include <string>
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/session_global_objects_scope.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
namespace oneflow {
inline Maybe<bool> IsSessionInited() {
return Singleton<SessionGlobalObjectsScope>::Get() != nullptr;
}
inline void FixCpuDeviceNum(ConfigProto* config_proto) {
if (config_proto->resource().cpu_device_num() > 0) { return; }
config_proto->mutable_resource()->set_cpu_device_num(std::thread::hardware_concurrency());
}
inline Maybe<void> InitEagerGlobalSession(const std::string& config_proto_str) {
CHECK_NOTNULL_OR_RETURN(Singleton<EnvDesc>::Get()) << "env not found";
ConfigProto config_proto;
CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto))
<< "failed to parse config_proto: " << config_proto_str;
FixCpuDeviceNum(&config_proto);
Singleton<CtrlClient>::Get()->PushKV("config_proto", config_proto);
CHECK_ISNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get());
Singleton<SessionGlobalObjectsScope>::SetAllocated(new SessionGlobalObjectsScope());
JUST(Singleton<SessionGlobalObjectsScope>::Get()->EagerInit(config_proto));
VLOG(3) << "NewGlobal " << typeid(SessionGlobalObjectsScope).name();
return Maybe<void>::Ok();
}
inline Maybe<void> InitLazyGlobalSession(const std::string& config_proto_str) {
CHECK_NOTNULL_OR_RETURN(Singleton<EnvDesc>::Get()) << "env not found";
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
ClusterInstruction::MasterSendSessionStart();
ConfigProto config_proto;
CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto))
<< "failed to parse config_proto: " << config_proto_str;
FixCpuDeviceNum(&config_proto);
Singleton<CtrlClient>::Get()->PushKV("config_proto", config_proto);
CHECK_ISNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get());
Singleton<SessionGlobalObjectsScope>::SetAllocated(new SessionGlobalObjectsScope());
JUST(Singleton<SessionGlobalObjectsScope>::Get()->Init(config_proto));
VLOG(3) << "NewGlobal " << typeid(SessionGlobalObjectsScope).name();
return Maybe<void>::Ok();
}
inline Maybe<void> DestroyLazyGlobalSession() {
if (Singleton<SessionGlobalObjectsScope>::Get() == nullptr) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
Singleton<SessionGlobalObjectsScope>::Delete();
return Maybe<void>::Ok();
}
inline Maybe<void> StartLazyGlobalSession() {
CHECK_NOTNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get()) << "session not found";
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
const JobSet& job_set = Singleton<LazyJobBuildAndInferCtxMgr>::Get()->job_set();
if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create("job_set.prototxt")->Write(job_set);
}
if (job_set.job().empty()) { return Error::JobSetEmptyError() << "no function defined"; }
CHECK_ISNULL_OR_RETURN(Singleton<Oneflow>::Get());
Singleton<CtrlClient>::Get()->PushKV("session_job_set", job_set);
Singleton<const InterJobReuseMemStrategy>::New(job_set.inter_job_reuse_mem_strategy());
Singleton<Oneflow>::New();
JUST(Singleton<Oneflow>::Get()->Init(job_set));
return Maybe<void>::Ok();
}
inline Maybe<void> StopLazyGlobalSession() {
if (Singleton<Oneflow>::Get() == nullptr) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
CHECK_NOTNULL_OR_RETURN(Singleton<Oneflow>::Get());
Singleton<Oneflow>::Delete();
Singleton<const InterJobReuseMemStrategy>::Delete();
return Maybe<void>::Ok();
}
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_SESSION_SESSION_H_
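A hedged ordering sketch for the lazy-session helpers declared in this header. It assumes an already-initialized environment, a master process, and at least one compiled job (StartLazyGlobalSession otherwise returns the "no function defined" error); the config prototxt only uses fields visible above (resource.cpu_device_num), and the module path is an assumption:

import oneflow as flow

internal = flow._oneflow_internal   # assumed module path for the bindings above
config_proto = "resource { cpu_device_num: 2 }"  # FixCpuDeviceNum would fill this if omitted
internal.InitLazyGlobalSession(config_proto)
# ... define and compile lazy jobs here ...
internal.StartLazyGlobalSession()
internal.StopLazyGlobalSession()
internal.DestroyLazyGlobalSession()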
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_conf.pb.h"
namespace py = pybind11;
namespace oneflow {
Maybe<JobDesc> CreateJobConfSymbol(int64_t symbol_id, const std::string& serialized_symbol_conf) {
JobConfigProto symbol_pb;
if (!TxtString2PbMessage(serialized_symbol_conf, &symbol_pb)) {
THROW(RuntimeError) << "job conf parse failed.\n" << serialized_symbol_conf;
}
return JobDesc::New(symbol_id, symbol_pb);
}
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<JobDesc, std::shared_ptr<JobDesc>>(m, "JobConfSymbol")
.def(py::init([](int64_t symbol_id, const std::string& serialized_symbol_conf) {
return CreateJobConfSymbol(symbol_id, serialized_symbol_conf).GetPtrOrThrow();
}))
.def_property_readonly("symbol_id",
[](const JobDesc& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("data", [](const JobDesc& job_conf_sym) -> std::string {
return PbMessage2TxtString(job_conf_sym.job_conf());
});
}
} // namespace oneflow
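An illustrative sketch of constructing a JobConfSymbol through the binding above, mirroring how the legacy Python front end created such symbols; the symbol id, the job_name value, and the oneflow._oneflow_internal module path are placeholders/assumptions:

import oneflow as flow

internal = flow._oneflow_internal   # assumed module path
job_conf_sym = internal.JobConfSymbol(123, 'job_name: "untitled_job"')  # prototxt of JobConfigProto
print(job_conf_sym.symbol_id)  # 123
print(job_conf_sym.data)       # prototxt round-tripped through PbMessage2TxtString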
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/operator/op_conf_symbol.h"
#include "oneflow/core/common/maybe.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<OperatorConfSymbol, std::shared_ptr<OperatorConfSymbol>>(m, "OpConfSymbol")
.def_property_readonly("symbol_id",
[](const OperatorConfSymbol& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("data", &OperatorConfSymbol::data);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/parallel_conf_util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace py = pybind11;
namespace oneflow {
namespace {
int64_t GetDeviceCount(const std::string& device_name) {
return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(device_name);
}
struct PlacementSymbolExportUtil {
static Maybe<void> CheckDeviceTag(const std::string& type) {
if (!TRY(DeviceType4DeviceTag(type)).IsOk()) {
return Error::RuntimeError() << "Expected one of " << PrintAvailableDevices()
<< " device type at start of device string: " << type;
}
return Maybe<void>::Ok();
}
static Maybe<ParallelDesc> CreateParallelDesc(
const std::string& type, const std::vector<std::string>& formated_machine_device_ids,
const std::shared_ptr<Shape>& hierarchy_shape) {
JUST(CheckDeviceTag(type));
auto parallel_conf = JUST(MakeParallelConf(type, formated_machine_device_ids, hierarchy_shape));
std::shared_ptr<ParallelDesc> parallel_desc;
JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf));
return Maybe<void>::Ok();
}));
return parallel_desc;
}
static Maybe<ParallelDesc> CreateParallelDesc(const std::string& proto_str) {
ParallelConf parallel_conf;
CHECK_OR_RETURN(TxtString2PbMessage(proto_str, &parallel_conf))
<< " Get ParallelConf Pb from string failed.";
std::shared_ptr<ParallelDesc> parallel_desc;
JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
return Maybe<void>::Ok();
}));
return parallel_desc;
}
static Maybe<std::vector<std::string>> ParseAndFormatRanks(const py::dict& device_ids) {
std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;
for (const auto& pair : device_ids) {
CHECK_OR_RETURN(py::isinstance<py::int_>(pair.first))
<< "The key (node id) of placement device_ids must be int64.";
int64_t machine_id = pair.first.cast<int64_t>();
if (py::isinstance<py::int_>(pair.second)) {
machine_device_id_vec.emplace_back(machine_id, pair.second.cast<int64_t>());
} else {
CHECK_OR_RETURN(py::isinstance<py::iterable>(pair.second))
<< "Value of device_ids dict must be int, list or range";
for (const auto& device_id : pair.second) {
CHECK_OR_RETURN(py::isinstance<py::int_>(device_id))
<< "Value of device_ids dict must be int, list or range of int.";
machine_device_id_vec.emplace_back(machine_id, device_id.cast<int64_t>());
}
}
}
auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();
for (const auto& pair : machine_device_id_vec) {
const std::string& device_name =
std::to_string(pair.first) + ":" + std::to_string(pair.second);
formated_machine_device_ids->emplace_back(device_name);
}
return formated_machine_device_ids;
}
static Maybe<Shape> GetRanksShape(PyArrayObject* ranks) {
auto* shape = PyArray_SHAPE(ranks);
return std::make_shared<Shape>(DimVector(shape, shape + PyArray_NDIM(ranks)));
}
// Parse and format ranks to string "machine_id:local_rank"
static Maybe<std::vector<std::string>> ParseAndFormatRanks(PyArrayObject* ranks) {
size_t size = PyArray_SIZE(ranks);
CHECK_EQ_OR_RETURN(PyArray_TYPE(ranks), NPY_INT64)
<< Error::RuntimeError() << "placement ranks shoule be an array of long int";
int64_t* rank_data = static_cast<int64_t*>(PyArray_DATA(ranks));
std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;
for (int i = 0; i < size; ++i) {
int64_t rank = rank_data[i];
int64_t machine_id = GlobalProcessCtx::NodeId(rank);
int64_t device_id = GlobalProcessCtx::LocalRank(rank);
machine_device_id_vec.emplace_back(machine_id, device_id);
}
auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();
for (const auto& pair : machine_device_id_vec) {
auto device_name = std::to_string(pair.first) + ":" + std::to_string(pair.second);
formated_machine_device_ids->emplace_back(device_name);
}
return formated_machine_device_ids;
}
static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(
const std::string& type, const py::dict& device_ids,
const std::shared_ptr<Shape>& hierarchy) {
const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(device_ids));
return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, hierarchy)));
}
// Create a Symbol<ParallelDesc> from the given device_type and ranks parameters
static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& type,
const py::object& ranks) {
auto* obj = reinterpret_cast<PyArrayObject*>(PyArray_FromAny(
ranks.ptr(), nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr));
if (!obj) { return Error::RuntimeError() << "placement ranks should be an array of long ints"; }
const auto& shape = JUST(GetRanksShape(obj));
const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(obj));
return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape)));
}
static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& proto_str) {
return SymbolOf(*JUST(CreateParallelDesc(proto_str)));
}
static Maybe<Symbol<ParallelDesc>> AllDevicePlacement(const std::string& type) {
static thread_local HashMap<std::string, Symbol<ParallelDesc>> device_tag2placement;
CHECK_NOTNULL((Singleton<ResourceDesc, ForEnv>::Get()));
JUST(CheckDeviceTag(type));
auto it = device_tag2placement.find(type);
if (it == device_tag2placement.end()) {
int64_t node_size = GlobalProcessCtx::NodeSize();
int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode();
if (type != "cpu") {
const int64_t device_count = GetDeviceCount(type);
CHECK_NE_OR_RETURN(device_count, 0)
<< Error::RuntimeError() << "Can\'t construct placement with \"" << type
<< "\" type because there is no device!";
device_num = std::min(device_num, device_count);
}
std::vector<std::string> machine_device_ids;
for (int64_t node_id = 0; node_id < node_size; ++node_id) {
std::string device_name = std::to_string(node_id) + ":0-" + std::to_string(device_num - 1);
machine_device_ids.emplace_back(device_name);
}
Symbol<ParallelDesc> placement =
SymbolOf(*JUST(CreateParallelDesc(type, machine_device_ids, std::shared_ptr<Shape>())));
it = device_tag2placement.emplace(type, placement).first;
}
return it->second;
}
static Maybe<py::array> GetPlacementRanks(const Symbol<ParallelDesc>& placement) {
py::list ranks;
for (int64_t machine_id : placement->sorted_machine_ids()) {
int64_t node_id = GlobalProcessCtx::NodeId(machine_id);
for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) {
ranks.append(py::cast(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id));
}
}
auto array_ranks = py::cast<py::array>(ranks);
array_ranks.resize(placement->hierarchy()->dim_vec());
return array_ranks;
}
};
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Symbol<ParallelDesc>, std::shared_ptr<Symbol<ParallelDesc>>>(m, "placement",
py::dynamic_attr())
.def(py::init([](const std::string& device_type, const py::dict& device_ids,
const std::shared_ptr<Shape>& hierarchy) {
PyErr_WarnEx(
PyExc_UserWarning,
"The way to construct placement is deprecated, and it will be removed in next "
"versions. Please use oneflow.placement(type=str, ranks=int array) instead",
1);
return PlacementSymbolExportUtil::CreateParallelDescSymbol(device_type, device_ids,
hierarchy)
.GetOrThrow();
}),
py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy"))
.def(py::init([](const std::string& device_type, const py::dict& device_ids,
const py::tuple& hierarchy) {
PyErr_WarnEx(
PyExc_UserWarning,
"The way to construct placement is deprecated, and it will be removed in next "
"versions. Please use oneflow.placement(type=str, ranks=int array) instead",
1);
DimVector shape_dims{};
for (const auto& dim : hierarchy) { shape_dims.emplace_back(dim.cast<int64_t>()); }
return PlacementSymbolExportUtil::CreateParallelDescSymbol(
device_type, device_ids, std::make_shared<Shape>(shape_dims))
.GetOrThrow();
}),
py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy") = py::tuple())
.def(py::init([](const std::string& type, const py::object& ranks) {
return PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow();
}),
py::arg("type"), py::arg("ranks"))
.def(py::init([](const std::string& proto_str) {
return PlacementSymbolExportUtil::CreateParallelDescSymbol(proto_str).GetOrThrow();
}),
py::arg("proto_str"))
.def_property_readonly(
"device_type",
[](Symbol<ParallelDesc> p) {
PyErr_WarnEx(
PyExc_UserWarning,
"The property .device_type of placement is deprecated, please use .type instead",
1);
return p->device_tag();
})
.def_property_readonly("type", [](Symbol<ParallelDesc> p) { return p->device_tag(); })
.def_property_readonly("hierarchy",
[](Symbol<ParallelDesc> p) {
PyErr_WarnEx(PyExc_UserWarning,
"The property .hierarchy of placement is deprecated, "
"please use .ranks.shape instead",
1);
return p->hierarchy();
})
.def_property_readonly("ranks", &PlacementSymbolExportUtil::GetPlacementRanks)
.def("__str__", PlacementToString)
.def("__repr__", PlacementToString)
.def(py::self == py::self)
.def(py::hash(py::self));
m.def("AllDevicePlacement", &PlacementSymbolExportUtil::AllDevicePlacement);
}
} // namespace oneflow
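A short usage sketch for the placement class exported above. flow.placement(type, ranks) and the .type/.ranks properties come straight from these bindings; flow.env.all_device_placement is assumed to be the public wrapper around the AllDevicePlacement helper:

import oneflow as flow

p = flow.placement("cpu", ranks=[0])           # preferred constructor: device type + global ranks
print(p.type)                                  # "cpu"
print(p.ranks)                                 # numpy array of ranks; its shape is the hierarchy
print(p == flow.placement("cpu", ranks=[0]))   # placements compare and hash by value

# Placement spanning every device of one type on all nodes (assumed wrapper name).
all_cpu = flow.env.all_device_placement("cpu")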
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/common/sbp.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/nd_sbp.h"
namespace py = pybind11;
namespace oneflow {
namespace {
Maybe<std::vector<Symbol<SbpParallel>>> MakeSplitSbpParallelList(int max_split_axis) {
std::shared_ptr<std::vector<Symbol<SbpParallel>>> ret =
std::make_shared<std::vector<Symbol<SbpParallel>>>(max_split_axis);
for (int i = 0; i < max_split_axis; ++i) { ret->at(i) = JUST(MakeSplitSbpParallel(i)); }
return ret;
}
Maybe<Symbol<SbpParallel>> GetSplitSbpParallel(int axis) {
CHECK_GE_OR_RETURN(axis, 0) << Error::RuntimeError()
<< "Split axis must not be negative, but got " << axis << "!";
CHECK_LT_OR_RETURN(axis, kMaxSplitAxis)
<< Error::RuntimeError() << "Expected split axis to be less than the supported maximum axis ("
<< kMaxSplitAxis << "), but got " << axis << "!";
static std::vector<Symbol<SbpParallel>> split_sbp_sym_list =
*JUST(MakeSplitSbpParallelList(kMaxSplitAxis));
return split_sbp_sym_list.at(axis);
}
Maybe<Symbol<SbpParallel>> GetBroadcastSbpParallel() {
static Symbol<SbpParallel> broadcast_sbp = JUST(MakeBroadcastSbpParallel());
return broadcast_sbp;
}
Maybe<Symbol<SbpParallel>> GetPartialSumSbpParallel() {
static Symbol<SbpParallel> partial_sum_sbp = JUST(MakePartialSumSbpParallel());
return partial_sum_sbp;
}
Maybe<std::pair<std::string, int>> SbpGetState(const Symbol<SbpParallel>& sbp) {
if (sbp->has_broadcast_parallel()) {
return std::make_shared<std::pair<std::string, int>>("B", -1);
} else if (sbp->has_partial_sum_parallel()) {
return std::make_shared<std::pair<std::string, int>>("P", -1);
} else if (sbp->has_split_parallel()) {
return std::make_shared<std::pair<std::string, int>>("S", sbp->split_parallel().axis());
} else {
return Error::RuntimeError() << "Invalid sbp signature: " << sbp->DebugString();
}
}
Maybe<Symbol<SbpParallel>> GetSbpFromState(const std::pair<std::string, int>& state) {
if (state.first == "B") {
return GetBroadcastSbpParallel();
} else if (state.first == "P") {
return GetPartialSumSbpParallel();
} else if (state.first == "S") {
return GetSplitSbpParallel(state.second);
} else {
return Error::RuntimeError() << "Invalid sbp signature state: (" << state.first << ", "
<< state.second << ");";
}
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
m.attr("max_split_axis") = kMaxSplitAxis;
py::class_<Symbol<SbpParallel>, std::shared_ptr<Symbol<SbpParallel>>>(m, "sbp",
py::dynamic_attr())
.def("__str__", &api::SbpToString)
.def("__repr__", &api::SbpToString)
.def(py::self == py::self)
.def(py::hash(py::self))
.def("_ToAttrStr",
[](const Symbol<SbpParallel>& sbp_sym) { return SbpParallelToString(*sbp_sym); })
.def(py::pickle(
[](const Symbol<SbpParallel>& sbp) { // __getstate__
return SbpGetState(sbp).GetOrThrow();
},
[](const std::pair<std::string, int>& state) { // __setstate__
return GetSbpFromState(state).GetOrThrow();
}));
m.def("split", GetSplitSbpParallel, py::arg("axis"));
m.def("broadcast", &GetBroadcastSbpParallel);
m.def("partial_sum", &GetPartialSumSbpParallel);
}
} // namespace oneflow
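A usage sketch for the sbp symbols exported above. split/broadcast/partial_sum map to the three SbpParallel variants, and the py::pickle hooks make the symbols picklable; it is assumed that flow.sbp.broadcast and flow.sbp.partial_sum already hold the constructed symbols on the Python side:

import pickle
import oneflow as flow

s0 = flow.sbp.split(0)      # S(0)
b = flow.sbp.broadcast      # B
p = flow.sbp.partial_sum    # P
assert pickle.loads(pickle.dumps(s0)) == s0  # round-trips via SbpGetState/GetSbpFromState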
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/scope.h"
namespace py = pybind11;
namespace oneflow {
Maybe<Scope> CreateScopeSymbol(int64_t symbol_id, const std::string& symbol_conf_str) {
ScopeProto symbol_pb;
if (!TxtString2PbMessage(symbol_conf_str, &symbol_pb)) {
THROW(RuntimeError) << "symbol conf parse failed.\n" << symbol_conf_str;
}
return Scope::New(symbol_id, symbol_pb);
}
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Scope, std::shared_ptr<Scope>>(m, "ScopeSymbol")
.def(py::init([](int64_t symbol_id, const std::string& symbol_conf_str) {
return CreateScopeSymbol(symbol_id, symbol_conf_str).GetPtrOrThrow();
}))
.def_property_readonly("symbol_id",
[](const Scope& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("_proto_str",
[](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); })
.def("auto_increment_id", &Scope::auto_increment_id)
.def_property_readonly("session_id", &Scope::session_id)
.def_property_readonly("job_desc_symbol", &Scope::job_desc_symbol)
.def_property_readonly(
"device_parallel_desc_symbol",
[](const Scope& x) { return x.device_parallel_desc_symbol().shared_from_symbol(); })
.def_property_readonly("parent_scope_symbol", &Scope::parent_scope_symbol)
.def("MakeChildScopeProto", &Scope::MakeChildScopeProto);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef _WIN32
#include <atomic>
#include <map>
#include <set>
#include <csignal>
#include <sstream>
#include <sys/wait.h>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include <stdexcept>
namespace oneflow {
namespace py = pybind11;
// reference: pytorch/torch/csrc/DataLoader.cpp
// https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/DataLoader.cpp
// Critical signal handlers should be registered on worker processes before
// doing work.
// Each handler restores the default handler and re-raises the signal, so that the
// kill information can be retrieved by the main process.
// The Python entry point is _set_worker_signal_handlers().
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \
static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) { \
auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
(void)_w; \
struct sigaction sa {}; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, nullptr) != 0) { \
_exit(EXIT_FAILURE); \
} else { \
raise(SIGNAL); \
} \
}
// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
static inline void setSignalHandler(int signal, void (*handler)(int, siginfo_t*, void*),
struct sigaction* old_sa_ptr) {
struct sigaction sa {};
sa.sa_sigaction = handler;
sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) {
std::ostringstream oss;
oss << "An error occurred while setting handler for " << strsignal(signal) << ".";
throw std::runtime_error(oss.str());
}
}
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS,
"ERROR: Unexpected bus error encountered in worker. "
"This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV,
"ERROR: Unexpected segmentation fault encountered in worker.\n");
SIGNAL_HANDLER(SIGFPE, handler_SIGFPE,
"ERROR: Unexpected floating-point exception encountered in worker.\n");
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the child
// processes before deleting the loader object. In that case the cleanup
// methods in DataLoader.__del__ have not been called yet, and the SIGCHLD handler
// will print an error saying a worker was killed by SIGTERM. So we suppress
// SIGTERM from the main loader process here via _exit(EXIT_SUCCESS). Note that if
// we exit with a nonzero code, the loader SIGCHLD handler may report RuntimeError
// again, which defeats the whole purpose.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); }
struct sigaction sa {};
sa.sa_handler = SIG_DFL;
sa.sa_flags = 0;
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) {
_exit(EXIT_FAILURE);
} else {
raise(SIGTERM);
}
}
static void set_worker_signal_handlers() {
setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static std::map<int64_t, std::set<pid_t>> worker_pids = {};
static void error_if_any_worker_fails() {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int error;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::set<pid_t>* pid_set;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
pid_t worker_pid;
siginfo_t infop;
// Only check the pids we care about
for (auto& w : worker_pids) {
pid_set = &(w.second);
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
worker_pid = *pid_it;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
infop.si_pid = 0;
error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0) continue;
if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error
std::ostringstream oss;
oss << "DataLoader worker (pid " << worker_pid << ") exited "
<< "unexpectedly with exit code " << infop.si_status << ". "
<< "Details are lost due to multiprocessing. Rerunning with "
<< "num_workers=0 may give better error trace.";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
throw std::runtime_error(oss.str());
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss;
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
<< "by signal: " << strsignal(infop.si_status) << ". ";
if (infop.si_status == SIGBUS) {
oss << "It is possible that dataloader's workers are out of shared memory. "
<< "Please try to raise your shared memory limit.";
}
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
throw std::runtime_error(oss.str());
}
}
}
}
inline int64_t utils_unpackLong(PyObject* obj) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int overflow;
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
if (value == -1 && PyErr_Occurred()) { throw py::value_error(); }
if (overflow != 0) { throw std::runtime_error("Overflow when unpacking long"); }
return (int64_t)value;
}
// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
// of pids we are interested in.
static void set_worker_pids(py::args py_args) {
PyObject* args = py_args.ptr();
if (PyTuple_GET_SIZE(args) != 2) {
throw py::type_error("_set_worker_pids expects exactly 2 arguments.");
}
int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));
if (worker_pids.find(key) != worker_pids.end()) {
throw py::value_error(
"_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
}
PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
if (!PyTuple_Check(child_pids)) {
py::print("_set_worker_pids expects a tuple for child_pids, but got: ",
Py_TYPE(child_pids)->tp_name);
throw py::type_error("_set_worker_pids expects a tuple for child_pids");
}
std::set<pid_t> pids_set = {};
auto size = PyTuple_GET_SIZE(child_pids);
for (int idx = 0; idx < size; idx++) {
PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
pids_set.insert(static_cast<pid_t>(utils_unpackLong(obj)));
}
worker_pids[key] = pids_set;
}
static void remove_worker_pids(py::args py_args) {
PyObject* args = py_args.ptr();
int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));
auto it = worker_pids.find(key);
if (it == worker_pids.end()) {
py::print("Cannot find worker information for _BaseDataLoaderIter with id :", key);
throw py::value_error("Cannot find worker information for _BaseDataLoaderIter");
}
worker_pids.erase(it);
}
#undef SIGNAL_HANDLER
#else
// dummy implementations for windows
static PyObject* set_worker_signal_handlers(PyObject* module, PyObject* _ignored) {
Py_RETURN_NONE;
}
static PyObject* set_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }
static PyObject* remove_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }
static PyObject* error_if_any_worker_fails(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }
#endif
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("_set_worker_signal_handlers", &set_worker_signal_handlers);
m.def("_set_worker_pids", &set_worker_pids);
m.def("_remove_worker_pids", &remove_worker_pids);
m.def("_error_if_any_worker_fails", &error_if_any_worker_fails);
}
} // namespace oneflow
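A hedged sketch of the worker-watchdog protocol bound above, in the style of OneFlow's DataLoader: each worker installs the signal handlers right after it starts, and the parent registers the worker pids under a per-iterator key so _error_if_any_worker_fails can report abnormal exits. The module path and the loader_id key are assumptions:

import multiprocessing as mp
import oneflow as flow

internal = flow._oneflow_internal  # assumed module path for the bindings above

def worker():
    internal._set_worker_signal_handlers()  # SIGBUS/SIGSEGV/SIGFPE/SIGTERM handlers
    # ... fetch batches ...

if __name__ == "__main__":
    proc = mp.Process(target=worker)
    proc.start()
    loader_id = 0                                    # placeholder key, one per loader iterator
    internal._set_worker_pids(loader_id, (proc.pid,))
    internal._error_if_any_worker_fails()            # raises if a registered worker died abnormally
    proc.join()
    internal._remove_worker_pids(loader_id)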
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/consistency_check.h"
#include "oneflow/core/functional/impl/common.h"
namespace py = pybind11;
namespace oneflow {
namespace one {
Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
JUST(functional::CheckInplaceValid(t));
std::shared_ptr<MirroredTensor> local_tensor;
if (t->is_local()) {
local_tensor = JUST(t->AsMirroredTensor());
} else {
local_tensor = JUST(t->cur_rank_phy_tensor());
}
CHECK_OR_RETURN(local_tensor->is_eager()) << "eager tensors supported only";
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
local_tensor,
[](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
of_blob->AsyncAutoMemset(0);
},
"mut"));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
template<typename T>
Maybe<void> CopyMirroredTensorFromUntypedArray(const std::shared_ptr<Tensor>& tensor,
PyObject* array) {
return CopyBetweenMirroredTensorAndNumpy<T>(tensor, array, BlobNumpyCopyUtil<T>::From, "mut",
/*block_host_until_done=*/false);
}
Maybe<std::string> GetCopyMirroredTensorToNumpyFuncName(DataType dtype) {
using namespace oneflow;
static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, std::make_shared<std::string>("_copy_to_numpy_" #type_cpp)},
OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
}
Maybe<std::string> GetCopyMirroredTensorFromNumpyFuncName(DataType dtype) {
using namespace oneflow;
static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, std::make_shared<std::string>("_copy_from_numpy_" #type_cpp)},
OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
}
Maybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
const auto& tensor = JUST(t->AsMirroredTensor());
if (tensor->dtype() != DType::TensorBuffer()) {
return Error::RuntimeError() << "tensor buffer supported only";
}
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
std::vector<Shape> shapes;
std::vector<Symbol<DType>> dtypes;
auto btb = std::make_shared<BlockingThenBusy>(1);
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->SyncAccessBlobByCallback(
tensor, btb, [](uint64_t) {}, "const");
}));
JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
const auto& eager_blob_object = JUST(tensor->eager_blob_object());
const Shape& blob_shape = eager_blob_object->shape();
const auto* tensor_buffer_ptr = eager_blob_object->dptr<TensorBuffer>();
for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {
const TensorBuffer* tensor_buffer = tensor_buffer_ptr + i;
shapes.emplace_back(tensor_buffer->shape());
dtypes.emplace_back(DType::Get(tensor_buffer->data_type()).GetOrThrow());
}
return std::make_tuple(shapes, dtypes);
}
Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,
const AutogradMeta::Hook& hook) {
CHECK_OR_RETURN(self->requires_grad())
<< "cannot register a hook on a tensor that doesn't require gradient";
if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
self->mut_autograd_meta()->add_hook(hook);
return Maybe<void>::Ok();
}
Maybe<void> RegisterTensorPostGradAccumulationHook(const std::shared_ptr<Tensor>& self,
const AutogradMeta::Hook& hook) {
if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
self->mut_autograd_meta()->add_post_grad_accumulation_hook(hook);
return Maybe<void>::Ok();
}
Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
const auto& nd_sbp = JUST(tensor.nd_sbp());
const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());
for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {
(*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));
}
return tuple;
}
#define MAKE_SWITCH_ENTRY(func_name, dtype) func_name<dtype>
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, CopyMirroredTensorFromUntypedArray, MAKE_SWITCH_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_AND_HALF_DATA_TYPE_SEQ));
Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const bool requires_grad, const bool pin_memory) {
PyObject* array = NULL;
PyArray_Descr* np_dtype =
dtype.has_value()
? PyArray_DescrFromType(JUST(numpy::OFDataTypeToNumpyType(JUST(dtype)->data_type())))
: nullptr;
// PyArray_FromAny steals a reference to np_dtype object, so no need to decref it.
// NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
// array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
// NPY_ARRAY_FORCECAST is needed, otherwise there will be a segfault.
array = PyArray_FromAny(data, np_dtype, 0, 0,
NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr);
if (!array) {
return Error::RuntimeError() << "Can not convert input data to a new numpy array.";
}
// flow.tensor([1., 2.]).dtype should be flow.float32 rather than flow.float64
if (!PyArray_Check(data)) {
int np_array_type = PyArray_TYPE(reinterpret_cast<PyArrayObject*>(array));
// Cast to float if data is double sequence, rather than numpy array.
if (np_array_type == NPY_DOUBLE && np_dtype == nullptr) {
PyObject* fp32_array = PyArray_Cast(reinterpret_cast<PyArrayObject*>(array), NPY_FLOAT);
Py_DECREF(array);
array = fp32_array;
}
}
auto* np_arr = reinterpret_cast<PyArrayObject*>(array);
const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));
Symbol<Device> device_;
if (device) {
device_ = JUST(device);
} else {
device_ = JUST(Device::New("cpu"));
}
std::shared_ptr<Tensor> tensor = JUST(
functional::Empty(shape, JUST(DType::Get(data_type)), device_, /*pin_memory=*/pin_memory));
JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), tensor, array));
Py_DECREF(array);
JUST(tensor->set_requires_grad(requires_grad));
return tensor;
}
namespace {
Maybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(size_t ndim) {
NdSbp broadcast_nd_sbp;
for (size_t i = 0; i < ndim; ++i) {
broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
}
return SymbolOf(broadcast_nd_sbp);
}
auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);
} // namespace
Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
Symbol<ParallelDesc> placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple,
const bool requires_grad) {
PyObject* array = NULL;
if (PyArray_Check(data)) {
// Only NPY_CORDER is supported, and returns a new C-style contiguous array.
array = PyArray_NewCopy((PyArrayObject*)data, NPY_CORDER);
} else {
// NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
// array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
array = PyArray_FromAny(data, nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr);
if (!array) { return Error::RuntimeError() << "Can not convert input data to a numpy array."; }
}
auto* np_arr = reinterpret_cast<PyArrayObject*>(array);
const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));
if (placement->parallel_num() > 1) {
const void* buf_ptr = PyArray_DATA(np_arr);
size_t array_size = PyArray_SIZE(np_arr);
CHECK_EQ_OR_RETURN(array_size, shape.elem_cnt());
size_t byte_size = array_size * GetSizeOfDataType(data_type);
JUST(DataConsistencyCheck(buf_ptr, byte_size, placement));
}
Symbol<Device> device = JUST(Device::New(placement->device_tag()));
std::shared_ptr<Tensor> local_tensor =
JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*pin_memory=*/false));
JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), local_tensor, array));
Py_DECREF(array);
// Cast to float if data is double sequence, rather than numpy array.
Symbol<DType> dtype_;
if (dtype) {
dtype_ = JUST(dtype);
} else if (!dtype && data_type == DataType::kDouble && !PyArray_Check(data)) {
dtype_ = DType::Float();
}
if (dtype_) { local_tensor = JUST(functional::Cast(local_tensor, dtype_, /*pin_memory=*/false)); }
size_t sbp_dims = sbp_tuple.size();
Symbol<NdSbp> broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(sbp_dims));
std::shared_ptr<Tensor> broadcast_tensor = JUST(functional::LocalToConsistent(
local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape, local_tensor->dtype()));
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
auto consistent_tensor = JUST(functional::ToConsistent(broadcast_tensor, placement, sbp_tuple,
grad_sbp_tuple, /* check_meta */ false));
JUST(consistent_tensor->set_requires_grad(requires_grad));
return consistent_tensor;
}
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const bool pin_memory) {
if (other->is_local()) {
const Symbol<Device>& device = JUST(other->device());
return functional::Copy(other, device->type(), device->device_id(), pin_memory);
} else {
const Symbol<NdSbp>& nd_sbp = JUST(other->nd_sbp());
const std::vector<Symbol<SbpParallel>>& sbp_tuple = *JUST(GetSbpList(nd_sbp));
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
// TODO:(zhaoluyang) consistent case support pin_memory
return functional::ToConsistent(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple,
/* check_meta */ false);
}
}
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const bool requires_grad, const bool pin_memory) {
std::shared_ptr<Tensor> tensor;
Symbol<Device> device_;
if (device) { device_ = JUST(device); }
if (other->is_local()) {
if (!device) { device_ = JUST(other->device()); }
tensor = JUST(functional::Copy(other, device_->type(), device_->device_id(),
pin_memory && !dtype.has_value()));
} else {
tensor = JUST(functional::ConsistentToLocal(other));
if (!device) { device_ = JUST(Device::New("cpu")); }
tensor = JUST(functional::Copy(tensor, device_->type(), device_->device_id(),
pin_memory && !dtype.has_value()));
}
if (dtype) {
const Symbol<DType>& dtype_ = JUST(dtype);
if (tensor->dtype() != dtype_) { tensor = JUST(functional::Cast(tensor, dtype_, pin_memory)); }
}
JUST(tensor->set_requires_grad(requires_grad));
return tensor;
}
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const Optional<Symbol<DType>>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple,
const bool requires_grad) {
std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
bool check_meta = !other->is_consistent();
std::shared_ptr<Tensor> tensor =
JUST(functional::ToConsistent(other, placement, sbp_tuple, grad_sbp_tuple, check_meta));
if (dtype) {
const Symbol<DType>& dtype_ = JUST(dtype);
if (tensor->dtype() != dtype_) {
tensor = JUST(functional::Cast(tensor, dtype_, /*pin_memory=*/false));
}
}
JUST(tensor->set_requires_grad(requires_grad));
return tensor;
}
} // namespace one
} // namespace oneflow
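A small sketch of the dtype behavior implemented in MakeLocalTensorFromData above: a plain Python float sequence is cast from double to float32, while an explicit float64 numpy array keeps its dtype (flow.tensor is assumed to route through this helper for local tensors):

import numpy as np
import oneflow as flow

t = flow.tensor([1.0, 2.0])              # double sequence -> cast to float32
print(t.dtype)                           # oneflow.float32
t64 = flow.tensor(np.array([1.0, 2.0]))  # numpy float64 array keeps float64
print(t64.dtype)                         # oneflow.float64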
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
#define ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include "oneflow/api/python/framework/tensor.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/blocking_then_busy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/common/foreign_lock_helper.h"
namespace py = pybind11;
namespace pybind11 {
// reference: https://github.com/pybind/pybind11/issues/1776
template<>
struct format_descriptor<oneflow::float16> {
static pybind11::dtype dtype() {
handle ptr = detail::npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
static std::string format() {
// following: https://docs.python.org/3/library/struct.html#format-characters
return "e";
}
static constexpr auto name() { return detail::_("float16"); }
};
} // namespace pybind11
namespace oneflow {
namespace one {
Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t);
template<typename T>
inline static Maybe<PyObject*> EagerMirroredTensorToNumpy(PyObject* py_tensor) {
const auto& t = PyTensor_Unpack(py_tensor);
std::shared_ptr<MirroredTensor> tensor = JUST(t->AsMirroredTensor());
CHECK_OR_RETURN(JUST(tensor->device()) == JUST(Device::New("cpu")));
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only.";
// set base object attr
py::handle handle = py::handle(py_tensor);
const size_t ndim = tensor->ndim();
const auto shape = numpy::OFShapeToNumpyShape(tensor->shape()->dim_vec());
// NumPy strides use bytes. OneFlow strides use element counts.
const auto stride =
numpy::OFStrideToNumpyStride(*JUST(tensor->stride()), tensor->dtype()->data_type());
T* data_ptr = nullptr;
const auto& Callback = [&](uint64_t ofblob_ptr) {
data_ptr = reinterpret_cast<OfBlob*>(ofblob_ptr)->mut_blob()->mut_dptr<T>();
};
auto btb = std::make_shared<BlockingThenBusy>(1);
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->SyncAccessBlobByCallback(tensor, btb, Callback, "mut");
}));
JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
return py::array(py::buffer_info(data_ptr, sizeof(T), py::format_descriptor<T>::format(), ndim,
shape, stride),
handle)
.release()
.ptr();
}
template<typename T>
inline Maybe<void> CopyBetweenMirroredTensorAndNumpy(
const std::shared_ptr<Tensor>& t, PyObject* array,
Maybe<void> (*Copy)(uint64_t, const NumPyArrayPtr&), const std::string& modifier,
bool block_host_until_done) {
auto tensor = JUST(t->AsMirroredTensor());
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only.";
if (block_host_until_done) {
NumPyArrayPtr array_ptr(array);
const auto& Callback = [array_ptr, Copy](uint64_t ofblob_ptr) {
CHECK_JUST(Copy(ofblob_ptr, array_ptr));
};
auto btb = std::make_shared<BlockingThenBusy>(1);
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->SyncAccessBlobByCallback(tensor, btb, Callback, modifier);
}));
JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));
} else {
Py_INCREF(array);
NumPyArrayPtr array_ptr(array, [array]() {
CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {
Py_DECREF(array);
return Maybe<void>::Ok();
}));
});
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->AccessBlobByCallback(
tensor,
[array_ptr, Copy](uint64_t ofblob_ptr) { CHECK_JUST(Copy(ofblob_ptr, array_ptr)); },
modifier);
}));
}
return Maybe<void>::Ok();
}
Maybe<std::string> GetCopyMirroredTensorToNumpyFuncName(DataType dtype);
Maybe<std::string> GetCopyMirroredTensorFromNumpyFuncName(DataType dtype);
Maybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t);
Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook);
Maybe<void> RegisterTensorPostGradAccumulationHook(const std::shared_ptr<Tensor>& self,
const AutogradMeta::Hook& hook);
Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor);
Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const bool requires_grad, const bool pin_memory);
Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
Symbol<ParallelDesc> placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple,
const bool requires_grad);
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const bool pin_memory);
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const bool requires_grad, const bool pin_memory);
Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
const Optional<Symbol<DType>>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple,
const bool requires_grad);
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
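A minimal sketch exercising the eager-tensor-to-numpy path declared above: tensor.numpy() on an eager local CPU tensor is assumed to go through EagerMirroredTensorToNumpy<T>, which requires the tensor to live on the CPU device (see the CHECK above):

import oneflow as flow

t = flow.ones(2, 3)   # eager local tensor on CPU
arr = t.numpy()       # produced by the CPU-only conversion path (assumption)
print(arr.shape)      # (2, 3)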
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <string>
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace {
void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int32_t depth,
int32_t max_depth, NdSbp& nd_sbp, std::vector<NdSbp>& nd_sbp_lists,
std::unordered_map<::oneflow::NdSbp, int32_t>& nd_sbp_universe) {
if (depth == max_depth) {
nd_sbp_universe[nd_sbp] = nd_sbp_lists.size();
nd_sbp_lists.push_back(nd_sbp);
} else {
for (const auto& sbp_parallel : id2sbp_parallel) {
*nd_sbp.mutable_sbp_parallel(depth) = sbp_parallel;
DfsSetNdSbp(id2sbp_parallel, depth + 1, max_depth, nd_sbp, nd_sbp_lists, nd_sbp_universe);
}
}
}
// Make an nd sbp consistent with the given hierarchy number
Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
// Do not need to change
if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; }
// (S0, S0) -> S0
if (hierarchy_num == 1) {
CHECK_OR_RETURN(Is1dSbp(nd_sbp))
<< NdSbpToString(nd_sbp) << " can not be converted to a 1d sbp!";
NdSbp new_sbp;
new_sbp.add_sbp_parallel();
*new_sbp.mutable_sbp_parallel(0) = nd_sbp.sbp_parallel(0);
return new_sbp;
}
// S0 -> (S0, S0)
CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1) << "Illegal nd sbp transform.";
NdSbp new_sbp;
for (int32_t i = 0; i < hierarchy_num; i++) {
new_sbp.add_sbp_parallel();
*new_sbp.mutable_sbp_parallel(i) = nd_sbp.sbp_parallel(0);
}
return new_sbp;
}
} // namespace
// A constructor with init, designed for uncustomized boxing collector
BoxingCollector::BoxingCollector(int32_t max_axis) { CHECK_JUST(Init(max_axis)); }
// Construct a boxing collector with given maximum number of axis
Maybe<void> BoxingCollector::Init(int32_t max_axis) {
  // Disable middle-node (multi-step) boxing and checking, for debugging
if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) {
return Maybe<void>::Ok();
}
  // Set up at least two split axes for the op graph.
  // A negative example: ResNet50 only has B, P, S(0)
CollectUniverse(max_axis);
GenerateNdSbpList(2);
GenerateMap1d2nd();
// Get copy cost in lazy mode
LazyMode::Guard enable_lazy_mode(true);
JUST(GenerateCombination4SamePlacement(3));
JUST(GenerateCombination4DiffHierarchy(this, this));
JUST(GenerateCombination4DiffPlacement(this, this));
return Maybe<void>::Ok();
}
// Customized initialization with given blob and parallel description
Maybe<void> BoxingCollector::Init(const BlobDesc& logical_blob_desc,
const ParallelDesc& parallel_desc) {
CollectUniverse(logical_blob_desc.shape().NumAxes());
GenerateNdSbpList(parallel_desc.hierarchy()->NumAxes());
// Filter out unsuitable middle nodes before computing minimum cost.
JUST(FilterNdSbpList4LogicalShape(logical_blob_desc, *parallel_desc.hierarchy()));
GenerateMap1d2nd();
// Get copy cost in lazy mode
LazyMode::Guard enable_lazy_mode(true);
JUST(GenerateCombination4SamePlacement(5, logical_blob_desc, parallel_desc));
return Maybe<void>::Ok();
}
// Collect Sbp Parallel
void BoxingCollector::CollectUniverse(const SbpParallel& sbp) {
if (sbp_parallel_universe_.find(sbp) == sbp_parallel_universe_.end()) {
int32_t curr_size = sbp_parallel_universe_.size();
sbp_parallel_universe_[sbp] = curr_size;
id2sbp_parallel_.push_back(sbp);
}
}
// Find corresponding id for Nd sbp
int32_t BoxingCollector::FindId4NdSbp(const NdSbp& nd_sbp) {
// Directly search on the nd_sbp_list
if (nd_sbp.sbp_parallel_size() == hierarchy_num_) {
const auto& it_nd_sbp = nd_sbp_universe_.find(nd_sbp);
if (it_nd_sbp != nd_sbp_universe_.end()) {
return it_nd_sbp->second;
} else {
return -1;
}
}
// Find the diagonal node if it could be converted to a 1D sbp
if (Is1dSbp(nd_sbp)) {
const auto& it_nd_sbp = sbp_parallel_universe_.find(nd_sbp.sbp_parallel(0));
if (it_nd_sbp != sbp_parallel_universe_.end()) { return id_1d_2_nd_[it_nd_sbp->second]; }
}
// Can not be converted to a 1D sbp or not found in the 1D sbp list
return -1;
}
// Set default Sbp list
void BoxingCollector::CollectUniverse(int32_t max_axis) {
SbpParallel sbp;
sbp.mutable_broadcast_parallel();
CollectUniverse(sbp);
for (int32_t axis = 0; axis < max_axis; axis++) {
sbp.mutable_split_parallel()->set_axis(axis);
CollectUniverse(sbp);
}
sbp.mutable_partial_sum_parallel();
CollectUniverse(sbp);
// Generate nd sbp list
void BoxingCollector::GenerateNdSbpList(int32_t hierarchy_num) {
// 1D sbp does not support S->P. But it seems that we do not need to deal with it for now.
// And we do not have 3D sbp or higher dimension.
hierarchy_num_ = hierarchy_num;
// Generate possible nd_sbp lists
NdSbp nd_sbp;
for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num; dim_sbp++) { nd_sbp.add_sbp_parallel(); }
DfsSetNdSbp(id2sbp_parallel_, 0, hierarchy_num, nd_sbp, nd_sbp_lists_, nd_sbp_universe_);
}
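// For example, with the default universe {B, S(0), S(1), P} and hierarchy_num == 2,
// the DFS above enumerates all 4^2 = 16 ordered pairs, e.g. (B, B), (B, S(0)), ...,
// (P, P); nd_sbp_universe_ maps each generated nd sbp back to its index in nd_sbp_lists_.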
// Generate the map from 1d sbp to 2d sbp
void BoxingCollector::GenerateMap1d2nd() {
// Number of 1d sbp
int32_t m = id2sbp_parallel_.size();
// Generate the id Map from 1d sbp to nd sbp
NdSbp nd_sbp;
for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { nd_sbp.add_sbp_parallel(); }
id_1d_2_nd_.resize(m, -1);
for (int32_t id_1d = 0; id_1d < m; id_1d++) {
for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) {
*nd_sbp.mutable_sbp_parallel(dim_sbp) = id2sbp_parallel_[id_1d];
}
// NOTE: The 2d sbp might be filtered out already.
const auto& it_ = nd_sbp_universe_.find(nd_sbp);
if (it_ != nd_sbp_universe_.end()) { id_1d_2_nd_[id_1d] = it_->second; }
}
}
// Generate the transfer rule for different combinations with the same hierarchy
Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num) {
// other parameters
  // NOTE: The result of this function is the same for different hierarchies
int32_t world_size = GlobalProcessCtx::WorldSize();
Shape hierarchy44({4 * world_size, 4 * world_size});
std::shared_ptr<Shape> virtual_hierarchy = std::make_shared<Shape>(hierarchy44);
auto parallel_desc = JUST(ParallelDesc::New(
"cpu", {"0:0-" + std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy));
BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc));
return Maybe<void>::Ok();
}
// Generate the transfer rule for different combinations with the same hierarchy
Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num,
const BlobDesc& blob_desc,
const ParallelDesc& parallel_desc) {
  // Store the original transfer cost information
int32_t n = nd_sbp_lists_.size();
minimum_copy_cost_.resize(n);
middle_nodes_.resize(n);
for (int32_t i = 0; i < n; i++) {
minimum_copy_cost_[i].resize(n);
middle_nodes_[i].resize(n);
for (int32_t j = 0; j < n; j++) {
minimum_copy_cost_[i][j] = JUST(ComputeLazyCopyCostBetweenNdSbp(
nd_sbp_lists_[i], nd_sbp_lists_[j], blob_desc, parallel_desc, parallel_desc,
/*requires_same_sbp=*/false));
}
}
auto NotMiddleNode = [&](int32_t i, int32_t j, int32_t k, int32_t middle_node_num_ik) -> bool {
    // Do not allow i -> i -> j or i -> j -> j.
if (k == j || k == i) { return true; }
// We add middle nodes one by one
// Thus, we allow multiple nodes from i to k but we only accept 1 step from k to j.
// i -> ? -> k -> j
if (middle_nodes_[k][j].size() > 0) { return true; }
// To avoid multiple counting and bugs, the number of middle nodes between i and k
// must be exactly middle_node_num_ik, which is (middle_node_num - 1)
if (middle_node_num_ik) {
if (middle_nodes_[i][k].size() == 0 || middle_nodes_[i][k][0].size() != middle_node_num_ik) {
return true;
}
} else {
if (middle_nodes_[i][k].size() > 0) { return true; }
}
return false;
};
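  // For example, when middle_node_num == 2 (so middle_node_num_ik == 1), a candidate
  // k is only accepted if middle_nodes_[i][k] already records exactly one middle node
  // (a path i -> a -> k found in the previous round) and k -> j is a direct transfer.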
for (int32_t middle_node_num = 1; middle_node_num <= max_middle_node_num; middle_node_num++) {
int32_t middle_node_num_ik = middle_node_num - 1;
for (int32_t i = 0; i < n; i++) {
for (int32_t j = 0; j < n; j++) {
if (minimum_copy_cost_[i][j] < GetValidMaxCopyCost()) { continue; }
// Compute the smallest transfer cost
// k is the middle node, i -> k -> j
for (int32_t k = 0; k < n; k++) {
if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; }
double curr_copy_cost = minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j];
if (curr_copy_cost < minimum_copy_cost_[i][j]) {
minimum_copy_cost_[i][j] = curr_copy_cost;
}
}
        // If the minimum copy cost remains infinity, adding one more middle node does not help.
if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { continue; }
// Find those middle nodes
for (int32_t k = 0; k < n; k++) {
if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; }
          // Now we start to judge whether the edge has the minimum cost
// It needs to be "<=" since we have 0 cost.
// Using "<" would give no middle nodes from (B, B) to any other nd sbp.
if (minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j]
<= minimum_copy_cost_[i][j] * 1.0000001) {
// i -> ? -> k
if (middle_nodes_[i][k].size() > 0) {
// We have multiple choices going from i to k
for (const auto& middle_node_ik : middle_nodes_[i][k]) {
middle_nodes_[i][j].push_back(middle_node_ik);
middle_nodes_[i][j][middle_nodes_[i][j].size() - 1].push_back(k);
}
} else {
// We only need one middle node k to reach j from i
middle_nodes_[i][j].push_back({k});
}
}
}
CHECK_OR_RETURN(middle_nodes_[i][j].size() > 0)
<< "No middle nodes given from " << NdSbpToString(nd_sbp_lists_[i]) << " to "
<< NdSbpToString(nd_sbp_lists_[j]) << " in boxing collector";
}
}
}
return Maybe<void>::Ok();
}
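// After this pass, minimum_copy_cost_[i][j] holds the cheapest known cost from
// nd_sbp_lists_[i] to nd_sbp_lists_[j], and middle_nodes_[i][j] holds the alternative
// shortest lists of intermediate nd sbp ids; an empty middle_nodes_[i][j] (with a valid
// minimum_copy_cost_) means the transfer can be done directly without middle nodes.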
// Generate the transfer rule for different combinations with different hierarchies on the same
// placement
Maybe<void> BoxingCollector::GenerateCombination4DiffHierarchy(
BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {
// Store the boxing collector pointer
// Search the path that contains one of the diagonal sbp
int32_t n = nd_sbp_lists_.size();
diag_node_diff_hierarchy_.resize(n);
for (int32_t i = 0; i < n; i++) {
diag_node_diff_hierarchy_[i].resize(n);
for (int32_t j = 0; j < n; j++) {
JUST(Generate1Combination4DiffHierarchy(i, j, boxing_collector_producer,
boxing_collector_consumer,
diag_node_diff_hierarchy_[i][j]));
}
}
return Maybe<void>::Ok();
}
// Generate the transfer rule for different combinations with different placements
Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {
// Virtual parallel and blob description
int32_t world_size = GlobalProcessCtx::WorldSize();
BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
// Virtual placements before transfer
Shape in_hierarchy44({4 * world_size + 1, 4 * world_size});
std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(in_hierarchy44);
auto in_parallel_desc = JUST(ParallelDesc::New(
"cpu", {"0:0-" + std::to_string(in_hierarchy44.elem_cnt() - 1)}, in_hierarchy));
// Virtual placements after transfer
Shape out_hierarchy44({4 * world_size, 4 * world_size});
std::shared_ptr<Shape> out_hierarchy = std::make_shared<Shape>(out_hierarchy44);
auto out_parallel_desc = JUST(ParallelDesc::New(
"cpu", {"0:0-" + std::to_string(out_hierarchy44.elem_cnt() - 1)}, out_hierarchy));
JUST(GenerateCombination4DiffPlacement(boxing_collector_producer, boxing_collector_consumer,
blob_desc, *in_parallel_desc, *out_parallel_desc));
return Maybe<void>::Ok();
}
// The cost for transferring a 1D sbp between different placements
Maybe<void> BoxingCollector::ComputeCostFor1DSbpDiffPlacement(
const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
std::vector<std::vector<double>>& cost_4_diff_placement) {
// Number of 1d sbp
int32_t m = id2sbp_parallel_.size();
// Compute the cost while transferring a 1D sbp between different placements
cost_4_diff_placement.resize(m);
for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) {
cost_4_diff_placement[id_1d_producer].resize(m, GetMaxVal<float>());
int32_t diag_producer = id_1d_2_nd_[id_1d_producer];
if (diag_producer < 0) { continue; }
for (int32_t id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) {
int32_t diag_consumer = id_1d_2_nd_[id_1d_consumer];
if (diag_consumer < 0) { continue; }
cost_4_diff_placement[id_1d_producer][id_1d_consumer] = JUST(ComputeLazyCopyCostBetweenNdSbp(
nd_sbp_lists_[diag_producer], nd_sbp_lists_[diag_consumer], blob_desc, in_parallel_desc,
out_parallel_desc, false));
}
}
return Maybe<void>::Ok();
}
// Generate the transfer rule for different combinations with different placements
Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer,
const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc) {
// The cost for transferring a 1D sbp between different placements
std::vector<std::vector<double>> cost_4_diff_placement;
// Compute the cost while transferring a 1D sbp between different placements
JUST(ComputeCostFor1DSbpDiffPlacement(blob_desc, in_parallel_desc, out_parallel_desc,
cost_4_diff_placement));
// Search the path that contains two of the diagonal sbp
int32_t n = nd_sbp_lists_.size();
diag_node_diff_placement_.resize(n);
for (int32_t i = 0; i < n; i++) {
diag_node_diff_placement_[i].resize(n);
for (int32_t j = 0; j < n; j++) {
JUST(Generate1Combination4DiffPlacement(i, j, boxing_collector_producer,
boxing_collector_consumer, cost_4_diff_placement,
diag_node_diff_placement_[i][j]));
}
}
return Maybe<void>::Ok();
}
// Print the cost and middle nodes
void BoxingCollector::PrintBoxingTables() {
if (GlobalProcessCtx::Rank() == 0) {
std::cout << "===================minimum copy cost==================" << std::endl;
// other parameters
    // Note that the result of this function is the same for different hierarchies
Shape hierarchy44({4, 4});
std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(hierarchy44);
double logical_blob_size = 1024.0;
int32_t n = nd_sbp_lists_.size();
    // Print the original copy cost table
std::cout << "Cost\t";
for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; }
std::cout << std::endl;
for (int32_t i = 0; i < n; i++) {
std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t";
for (int32_t j = 0; j < n; j++) {
if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {
std::cout << "X\t";
} else {
std::cout << minimum_copy_cost_[i][j] << "\t";
}
}
std::cout << std::endl;
}
std::cout << std::endl;
std::cout << "Original Copy Cost" << std::endl;
std::cout << "logical blob size: " << logical_blob_size << std::endl;
std::cout << "hierarchy: " << *in_hierarchy << std::endl;
std::cout << "============================middle nodes===========================" << std::endl;
// Print the middle nodes
std::cout << "Middle Sbp\t";
for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; }
std::cout << std::endl;
for (int32_t i = 0; i < n; i++) {
std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t";
for (int32_t j = 0; j < n; j++) {
if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {
std::cout << "X";
} else if (middle_nodes_[i][j].size() > 0) {
for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) {
std::cout << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][0]]);
for (int32_t l = 1; l < middle_nodes_[i][j][k].size(); l++) {
std::cout << "->" << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][l]]);
}
std::cout << "; ";
}
}
std::cout << "\t";
}
std::cout << std::endl;
}
std::cout << std::endl;
std::cout << "Minimum Copy Cost after second search" << std::endl;
std::cout << "logical blob size: " << logical_blob_size << std::endl;
std::cout << "hierarchy: " << *in_hierarchy << std::endl;
std::cout << "====================middle nodes for different placement===================="
<< std::endl;
std::cout << "Middle nodes for different placement\t";
for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; }
std::cout << std::endl;
for (int32_t i = 0; i < n; i++) {
std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t";
for (int32_t j = 0; j < n; j++) {
if (diag_node_diff_placement_[i][j].size() > 0) {
for (int32_t k = 0; k < diag_node_diff_placement_[i][j].size(); k++) {
std::cout << "[" << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][0]])
<< ", " << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][1]])
<< "]; ";
}
}
std::cout << "\t";
}
std::cout << std::endl;
}
std::cout << "====================middle nodes for different hierarchy===================="
<< std::endl;
std::cout << "Middle nodes for different hierarchy\t";
for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << "\t"; }
std::cout << std::endl;
for (int32_t i = 0; i < n; i++) {
std::cout << NdSbpToString(nd_sbp_lists_[i]) << "\t";
for (int32_t j = 0; j < n; j++) {
if (diag_node_diff_hierarchy_[i][j].size() > 0) {
for (int32_t k = 0; k < diag_node_diff_hierarchy_[i][j].size(); k++) {
std::cout << NdSbpToString(nd_sbp_lists_[diag_node_diff_hierarchy_[i][j][k][0]])
<< "; ";
}
}
std::cout << "\t";
}
std::cout << std::endl;
}
std::cout << "================================================" << std::endl;
}
}
// Ask if the boxing algorithm accepts the current sbp combination
Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps,
int32_t* diag_node_pos, bool compute_cost) {
middle_sbps.clear();
  // Disable middle-node (multi-step) boxing and checking, for debugging
if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) {
return Maybe<void>::Ok();
}
  // If compute_cost==false + 2D sbp + same placement + nccl logical + not (p->b),
// Use nccl logical send recv instead of middle node.
// Note that in op sbp inference, cost of middle nodes is still used for the moment.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
if (compute_cost == false && producer_parallel_desc.hierarchy()->NumAxes() == 2
&& producer_parallel_desc == consumer_parallel_desc
&& !(NdSbpHasPartialParallel(sbp_consumer)) &&
// TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open
// this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P)
      // -> (P, S0) -> (B, S0); neither same dim 0 nor send recv in nccl logical pass can deal with
// (P, P) -> (P, S0) at the moment.
// !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) &&
Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
VLOG(3) << "Middle node insertion is skipped when src sbp is " << NdSbpToString(sbp_producer)
<< " dst sbp is " << NdSbpToString(sbp_consumer)
<< ", because nccl logical send/recv can handle this.";
return Maybe<void>::Ok();
}
#endif  // WITH_CUDA || WITH_ROCM
// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) {
if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) {
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()
&& sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {
return Maybe<void>::Ok();
}
if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) {
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
      // We do not support [3]: P <--> [2, 2]: (P, P) either.
int32_t hierarchy_size = 0;
if (producer_parallel_desc.hierarchy()->elem_cnt()
< consumer_parallel_desc.hierarchy()->elem_cnt()) {
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*diag_node_pos = 1;
hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();
} else {
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*diag_node_pos = 0;
hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();
}
NdSbp broadcast_nd;
for (int32_t i = 0; i < hierarchy_size; i++) {
broadcast_nd.add_sbp_parallel();
broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
}
middle_sbps.emplace_back(broadcast_nd);
}
return Maybe<void>::Ok();
}
}
  // The middle-node algorithm supports transfers between different machines, devices, or hierarchies
if (producer_parallel_desc != consumer_parallel_desc) {
JUST(AskSbpCombination4DiffPlacement(sbp_producer, sbp_consumer, logical_blob_desc,
producer_parallel_desc, consumer_parallel_desc,
is_customized, middle_sbps, diag_node_pos, compute_cost));
return Maybe<void>::Ok();
}
// Transfer for the same machines, devices and hierarchy.
if (sbp_producer == sbp_consumer) { return Maybe<void>::Ok(); }
const auto& parallel_hierarchy = producer_parallel_desc.hierarchy();
*diag_node_pos = 0;
// Dealing with nD sbp, n>2
if (parallel_hierarchy->NumAxes() > 2) {
CHECK_OR_RETURN(compute_cost)
<< "Boxing does not support a hierarchy with dimension greater than 2";
return Maybe<void>::Ok();
}
// Ask for sbp combination with the same 2-D hierarchy and placement
JUST(AskSbpCombination4Same2DPlacement(sbp_producer, sbp_consumer, logical_blob_desc,
producer_parallel_desc, consumer_parallel_desc,
is_customized, middle_sbps, diag_node_pos, compute_cost));
return Maybe<void>::Ok();
}
// Ask for sbp combination with the same 2-D hierarchy and placement
Maybe<void> BoxingCollector::AskSbpCombination4Same2DPlacement(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
bool compute_cost) {
CHECK_OR_RETURN(producer_parallel_desc == consumer_parallel_desc)
<< "Producer and consumer have different placements, Please use AskSbpCombination directly";
middle_sbps.clear();
// Find the 2D sbp id
int32_t i = FindId4NdSbp(sbp_producer);
int32_t j = FindId4NdSbp(sbp_consumer);
// Dealing with 2D sbp
if (i >= 0 && j >= 0) {
    // Such a combination cannot be supported with a limited number of middle nodes
if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {
CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer)
<< " -> " << NdSbpToString(sbp_consumer) << " for 2D sbp";
return Maybe<void>::Ok();
}
    // The current design can deal with such a combination; no need to insert middle nodes
if (middle_nodes_[i][j].size() == 0) { return Maybe<void>::Ok(); }
// Find a list of middle nodes with minimum storage
int32_t min_k = -1;
double min_cost = GetValidMaxCopyCost();
for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) {
double curr_cost = 0.0;
for (int32_t middle_sbp_id : middle_nodes_[i][j][k]) {
Shape logical_shape = logical_blob_desc.shape();
        // Storage4NdSbp would modify logical_shape as well
curr_cost += Storage4NdSbp(nd_sbp_lists_[middle_sbp_id], logical_shape,
*producer_parallel_desc.hierarchy());
if (curr_cost > GetValidMaxCopyCost()) { break; }
}
      // Store k if it renews the minimum cost
if (curr_cost < min_cost) {
min_k = k;
min_cost = curr_cost;
}
}
    // If we found a list of middle nodes with the current boxing collector
int32_t producer_hierarchy_num = producer_parallel_desc.hierarchy()->NumAxes();
if (min_k >= 0) {
for (int32_t middle_sbp_id : middle_nodes_[i][j][min_k]) {
middle_sbps.emplace_back(
*JUST(SetNdSbpDim(nd_sbp_lists_[middle_sbp_id], producer_hierarchy_num)));
}
return Maybe<void>::Ok();
}
}
  // If we cannot find a list of middle nodes even with a customized boxing collector
if (is_customized) {
CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer)
<< " -> " << NdSbpToString(sbp_consumer)
<< " for Shape: " << logical_blob_desc.shape();
return Maybe<void>::Ok();
}
  // Customize a boxing collector and try the algorithm again
BoxingCollector customized_boxing_collector;
JUST(customized_boxing_collector.Init(logical_blob_desc, producer_parallel_desc));
JUST(customized_boxing_collector.AskSbpCombination4Same2DPlacement(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc,
/*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost));
return Maybe<void>::Ok();
}
// Ask for sbp combination with different hierarchies and placements
Maybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
bool compute_cost) {
middle_sbps.clear();
// Find the 2D sbp id
int32_t i = FindId4NdSbp(sbp_producer);
int32_t j = FindId4NdSbp(sbp_consumer);
// Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda
// Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2]
bool same_placement = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);
// Dealing with 2D sbp
if (i >= 0 && j >= 0) {
// Pure copy between machines and devices
if (i == j && (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy())) {
return Maybe<void>::Ok();
}
if (same_placement) {
// Different hierarchies
CHECK_OR_RETURN(diag_node_diff_hierarchy_.size() > 0)
<< "Have not initialzie the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffHierarchy(this, this)); "
"before Asking sbp combination for different parallel description.";
if (JUST(Ask1Combination4DiffPlacement(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this,
this, diag_node_diff_hierarchy_[i][j]))) {
return Maybe<void>::Ok();
}
} else {
// Different placements
CHECK_OR_RETURN(diag_node_diff_placement_.size() > 0)
<< "Have not initialzie the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffPlacement(this, this)); "
"before Asking sbp combination for different parallel description.";
if (JUST(Ask1Combination4DiffPlacement(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this,
this, diag_node_diff_placement_[i][j]))) {
return Maybe<void>::Ok();
}
}
}
  // If we cannot find a combination even with a customized boxing collector, report an error
if (is_customized) {
CHECK_OR_RETURN(compute_cost) << "Boxing does not support " << NdSbpToString(sbp_producer)
<< "[hierarchy: " << *producer_parallel_desc.hierarchy()
<< "] -> " << NdSbpToString(sbp_consumer)
<< "[hierarchy: " << *consumer_parallel_desc.hierarchy()
<< "] for blob shape: " << logical_blob_desc.shape();
return Maybe<void>::Ok();
}
// Customize boxing collector for producer
BoxingCollector customized_boxing_collector_producer;
JUST(customized_boxing_collector_producer.Init(logical_blob_desc, producer_parallel_desc));
// Customize boxing collector for consumer
BoxingCollector customized_boxing_collector_consumer;
JUST(customized_boxing_collector_consumer.Init(logical_blob_desc, consumer_parallel_desc));
std::vector<std::vector<int32_t>> diag_nodes;
// Generate the combination table for different hierarchies or placements
if (same_placement) {
JUST(customized_boxing_collector_producer.Generate1Combination4DiffHierarchy(
customized_boxing_collector_producer.FindId4NdSbp(sbp_producer),
customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer),
&customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes));
} else {
// Compute the cost while transferring a 1D sbp between different placements
std::vector<std::vector<double>> cost_4_diff_placement;
JUST(ComputeCostFor1DSbpDiffPlacement(logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, cost_4_diff_placement));
JUST(customized_boxing_collector_producer.Generate1Combination4DiffPlacement(
customized_boxing_collector_producer.FindId4NdSbp(sbp_producer),
customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer),
&customized_boxing_collector_producer, &customized_boxing_collector_consumer,
cost_4_diff_placement, diag_nodes));
}
JUST(customized_boxing_collector_producer.Ask1Combination4DiffPlacement(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc,
/*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost,
&customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes));
return Maybe<void>::Ok();
}
// Generate the transfer rule for one combination with different hierarchies on the same
// placement. id_producer -> id_consumer.
Maybe<void> BoxingCollector::Generate1Combination4DiffHierarchy(
int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer, std::vector<std::vector<int32_t>>& diag_nodes) {
// Number of 1d sbp
int32_t m = id2sbp_parallel_.size();
// Search the path that contains one of the diagonal sbp
  // minimum number of nodes
int32_t min_path_length = 100;
// minimum cost
double min_cost = GetValidMaxCopyCost();
for (int32_t id_1d = 0; id_1d < m; id_1d++) {
// We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)
// Thus, the diagonal node should suit both the hierarchies.
int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d];
if (diag_producer < 0) { continue; }
int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d];
if (diag_consumer < 0) { continue; }
// Find the path with minimum number of nodes
int32_t path_length = 0;
// Transfer from id_producer to id_2d
if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) {
path_length +=
boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1;
} else if (id_producer != diag_producer) {
path_length++;
}
// Transfer from id_2d to id_consumer
if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) {
path_length +=
boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1;
} else if (diag_consumer != id_consumer) {
path_length++;
}
// Pick the path with minimum copy cost
if (path_length <= min_path_length) {
double curr_cost =
boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]
+ boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer];
min_path_length = path_length;
// Find a candidate with small cost
if (curr_cost < min_cost * 1.0000001) {
// Find a smaller cost, clear the previous path.
if (curr_cost < min_cost * 0.9999999) {
min_cost = curr_cost;
diag_nodes.clear();
}
        // Add the current diagonal node pair.
        // Asymmetry happens here: diag_producer is an id in the producer's table and
        // diag_consumer is an id in the consumer's table.
diag_nodes.push_back({diag_producer, diag_consumer});
}
}
}
return Maybe<void>::Ok();
}
// Ask for one combination with different hierarchies and placements
Maybe<bool> BoxingCollector::Ask1Combination4DiffPlacement(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos, bool compute_cost,
BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<int32_t>>& diag_nodes) {
// Pick the path with minimum storage for the diagonal node
int32_t id_producer = boxing_collector_producer->FindId4NdSbp(sbp_producer);
if (id_producer < 0) {
CHECK_OR_RETURN(compute_cost) << "Source data with shape " << logical_blob_desc.shape()
<< " has an invalid sbp " << NdSbpToString(sbp_producer);
return false;
}
int32_t id_consumer = boxing_collector_consumer->FindId4NdSbp(sbp_consumer);
if (id_consumer < 0) {
CHECK_OR_RETURN(compute_cost) << "Target data with shape " << logical_blob_desc.shape()
<< " has an invalid sbp " << NdSbpToString(sbp_consumer);
return false;
}
middle_sbps.clear();
  // NOTE: For simplicity, we do not dig into the storage cost of the other middle nodes at
  // this moment.
double min_cost = GetValidMaxCopyCost();
int32_t producer_hierarchy_num_axes = producer_parallel_desc.hierarchy()->NumAxes();
int32_t consumer_hierarchy_num_axes = consumer_parallel_desc.hierarchy()->NumAxes();
int32_t min_diag_producer = -1, min_diag_consumer = -1;
for (const auto& diag_pair : diag_nodes) {
Shape logical_shape = logical_blob_desc.shape();
// We do not check whether such shape is valid under two side of the sbp list in the
// middle nodes algorithm. Thus, we need to check them here.
double curr_cost =
Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[diag_pair[0]],
producer_hierarchy_num_axes)),
logical_shape, *producer_parallel_desc.hierarchy());
// Check the shape for both producer and consumer.
logical_shape = logical_blob_desc.shape();
curr_cost +=
Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[diag_pair[1]],
consumer_hierarchy_num_axes)),
logical_shape, *consumer_parallel_desc.hierarchy());
if (curr_cost < min_cost) {
min_cost = curr_cost;
min_diag_producer = diag_pair[0];
min_diag_consumer = diag_pair[1];
}
}
// Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda
// Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2]
bool diff_placement = !producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);
// If we found a diagonal middle node with current boxing collector
if (min_diag_producer >= 0) {
std::vector<NdSbp> middle_sbps_buffer;
// Find the middle nodes between the producer and the diagonal node
if (id_producer != min_diag_producer) {
JUST(boxing_collector_producer->AskSbpCombination(
sbp_producer, boxing_collector_producer->nd_sbp_lists_[min_diag_producer],
logical_blob_desc, producer_parallel_desc, producer_parallel_desc,
/*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost));
// Add the path into middle_sbps
for (auto& middle_sbp : middle_sbps_buffer) {
middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, producer_hierarchy_num_axes)));
}
// If different placement,
// or the same placement but with 2D hierarchies
// For example: Oneflow supports [6]: (S0) -> [3, 2]: (S0, S1)
// but does not support [2, 3]: (S0, S0) -> [3, 2]: (S0, S1)
if (diff_placement || producer_hierarchy_num_axes > 1) {
middle_sbps.emplace_back(
*JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[min_diag_producer],
producer_hierarchy_num_axes)));
}
}
// If we do not have middle nodes on the consumer side
*diag_node_pos = middle_sbps.size();
// Find the middle nodes between the diagonal node and the consumer
if (id_consumer != min_diag_consumer) {
JUST(boxing_collector_consumer->AskSbpCombination(
boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer], sbp_consumer,
logical_blob_desc, consumer_parallel_desc, consumer_parallel_desc,
/*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost));
// Set the diagonal node position and stop using it as buffer
*diag_node_pos = middle_sbps.size();
// If different placement
if (diff_placement || consumer_hierarchy_num_axes > 1) {
middle_sbps.emplace_back(
*JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer],
consumer_hierarchy_num_axes)));
}
// Add the path into middle_sbps
for (auto& middle_sbp : middle_sbps_buffer) {
middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, consumer_hierarchy_num_axes)));
}
}
return true;
}
return false;
}
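// For example, a successful answer for different placements typically comes back as
//   middle_sbps = [producer-side nodes..., diag_producer, diag_consumer, consumer-side nodes...]
// with *diag_node_pos pointing at the first consumer-side entry, so everything before
// *diag_node_pos uses the producer's parallel description and the rest uses the
// consumer's, matching the contract documented for AskSbpCombination.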
// Generate the transfer rule for one combination with different placements
// id_producer -> id_consumer.
Maybe<void> BoxingCollector::Generate1Combination4DiffPlacement(
int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<double>>& cost_4_diff_placement,
std::vector<std::vector<int32_t>>& diag_nodes) {
// Number of 1d sbp
int32_t m = id2sbp_parallel_.size();
  // minimum number of nodes
int32_t min_path_length = 100;
// minimum cost
double min_cost = GetValidMaxCopyCost();
// Search the path that contains two of the diagonal sbp
// From the producer to the first diagonal node
for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) {
// We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)
// Thus, the diagonal node should suit both the hierarchies.
int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d_producer];
if (diag_producer < 0
|| boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]
> GetValidMaxCopyCost()) {
continue;
}
// Find the path with minimum number of nodes
int32_t path_length = 0;
// Transfer from id_producer to diag_producer
if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) {
path_length +=
boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1;
} else if (id_producer != diag_producer) {
path_length++;
}
// pruning
if (path_length > min_path_length) { continue; }
// From the second diagonal node to the consumer
for (int32_t id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) {
int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d_consumer];
// The diagonal sbp is not supported or no paths exist from the diagonal sbp to the
// consumer or between the two diagonal sbps.
if (diag_consumer < 0
|| boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer]
> GetValidMaxCopyCost()
|| cost_4_diff_placement[id_1d_producer][id_1d_consumer] > GetValidMaxCopyCost()) {
continue;
}
// Transfer from diag_consumer to id_consumer
int32_t curr_path_length = path_length;
if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) {
curr_path_length +=
boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1;
} else if (diag_consumer != id_consumer) {
curr_path_length++;
}
// Pick the path with minimum copy cost
if (curr_path_length <= min_path_length) {
double curr_cost =
boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]
+ cost_4_diff_placement[id_1d_producer][id_1d_consumer]
+ boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer];
min_path_length = curr_path_length;
// Find a candidate with small cost
if (curr_cost < min_cost * 1.0000001) {
// Find a smaller cost, clear the previous path.
if (curr_cost < min_cost * 0.9999999) {
min_cost = curr_cost;
diag_nodes.clear();
}
          // Add the current diagonal node pair.
          // Asymmetry happens here: diag_producer is an id in the producer's table and
          // diag_consumer is an id in the consumer's table.
diag_nodes.push_back({diag_producer, diag_consumer});
}
}
}
}
return Maybe<void>::Ok();
}
// Filter nd sbp from nd_sbp_lists_ with given logical shape
Maybe<void> BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc,
const Shape& parallel_hierarchy) {
for (int32_t middle_sbp_id = nd_sbp_lists_.size() - 1; middle_sbp_id >= 0; middle_sbp_id--) {
Shape logical_shape = logical_blob_desc.shape();
if (JUST(FilterNdSbpByLogicalShape(nd_sbp_lists_[middle_sbp_id], logical_shape,
parallel_hierarchy))) {
// Change the value before erasing
// This might be true: nd_sbp_lists_.size() - 1 == middle_sbp_id
nd_sbp_universe_[nd_sbp_lists_[nd_sbp_lists_.size() - 1]] = middle_sbp_id;
nd_sbp_universe_.erase(nd_sbp_lists_[middle_sbp_id]);
nd_sbp_lists_[middle_sbp_id] = nd_sbp_lists_[nd_sbp_lists_.size() - 1];
nd_sbp_lists_.pop_back();
}
}
return Maybe<void>::Ok();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/sbp_infer_util.h"
namespace oneflow {
class BoxingCollector final {
public:
BoxingCollector() = default;
~BoxingCollector() = default;
// A constructor with init, designed for uncustomized boxing collector
BoxingCollector(int32_t max_axis);
// Set default Sbp list
void CollectUniverse(int32_t max_axis);
// Construct a boxing collector with given maximum number of axis
Maybe<void> Init(int32_t max_axis);
// Init with given blob description
Maybe<void> Init(const BlobDesc& logical_blob_desc, const ParallelDesc& parallel_desc);
// Generate nd sbp list
void GenerateNdSbpList(int32_t hierarchy_num);
// Generate the map from 1d sbp to 2d sbp
void GenerateMap1d2nd();
// Generate the transfer rule for different combinations with the same hierarchy
Maybe<void> GenerateCombination4SamePlacement(int32_t max_middle_node_num);
Maybe<void> GenerateCombination4SamePlacement(int32_t max_middle_node_num,
const BlobDesc& blob_desc,
const ParallelDesc& parallel_desc);
// Generate the transfer rule for different combinations with different hierarchies
// on the same placement
Maybe<void> GenerateCombination4DiffHierarchy(BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer);
// Generate the transfer rule for different combinations with different placements
Maybe<void> GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer);
Maybe<void> GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const BlobDesc& blob_desc,
const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc);
// Print the cost and middle nodes
void PrintBoxingTables();
// Ask if the boxing algorithm accepts the current sbp combination
  // If is_customized is true and we cannot find a middle node list with
  // reasonable cost, an error occurs.
  // If compute_cost is true, then no error occurs even if no suitable middle node path is found.
// For different placements, we would return a diagonal node.
// Before this diagonal node (< *diag_node_pos), we use the parallel description of the producer.
// After this diagonal node (>= *diag_node_pos), we use the parallel description of the consumer.
Maybe<void> AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc, bool is_customized,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
bool compute_cost);
// Filter nd sbp from nd_sbp_lists_ with given logical shape
Maybe<void> FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc,
const Shape& parallel_hierarchy);
private:
// Collect Sbp Parallel
void CollectUniverse(const SbpParallel& sbp);
// Find corresponding id for Nd sbp
int32_t FindId4NdSbp(const NdSbp& nd_sbp);
// Ask for sbp combination with the same 2-D hierarchy and placement
Maybe<void> AskSbpCombination4Same2DPlacement(const NdSbp& sbp_producer,
const NdSbp& sbp_consumer,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps,
int32_t* diag_node_pos, bool compute_cost);
// Ask for sbp combination with different hierarchies on the same placement
Maybe<void> AskSbpCombination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps,
int32_t* diag_node_pos, bool compute_cost);
// Generate the transfer rule for one combination with different hierarchies on the same
// placement. id_producer -> id_consumer.
Maybe<void> Generate1Combination4DiffHierarchy(int32_t id_producer, int32_t id_consumer,
BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
std::vector<std::vector<int32_t>>& diag_nodes);
// The cost for transferring a 1D sbp between different placements
Maybe<void> ComputeCostFor1DSbpDiffPlacement(
const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
std::vector<std::vector<double>>& cost_4_diff_placement);
// Generate the transfer rule for one combination with different placements
// id_producer -> id_consumer.
Maybe<void> Generate1Combination4DiffPlacement(
int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<double>>& cost_4_diff_placement,
std::vector<std::vector<int32_t>>& diag_nodes);
// Ask for one combination with different hierarchies and placements
Maybe<bool> Ask1Combination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_customized, std::vector<NdSbp>& middle_sbps,
int32_t* diag_node_pos, bool compute_cost,
BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<int32_t>>& diag_nodes);
// Stores all the possible SbpParallel.
HashMap<::oneflow::SbpParallel, int32_t> sbp_parallel_universe_;
// Relationship between id and Sbp Parallel
std::vector<::oneflow::SbpParallel> id2sbp_parallel_;
// minimum cost
// minimum_copy_cost[producer][consumer]
std::vector<std::vector<double>> minimum_copy_cost_;
// middle nodes
// middle_nodes_[producer][consumer][different choices] is a vector of middle nodes
// middle_nodes_[producer][consumer][different choices].size() is the minimum number of middle
// nodes that needs to be inserted
std::vector<std::vector<std::vector<std::vector<int32_t>>>> middle_nodes_;
// Stores all the possible NdSbp.
std::unordered_map<::oneflow::NdSbp, int32_t> nd_sbp_universe_;
// Relationship between id and Nd Sbp
std::vector<NdSbp> nd_sbp_lists_;
  // The diagonal middle node for different placements
std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_placement_;
  // The diagonal middle node for different hierarchies in the same placement
std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_hierarchy_;
// Id Map from 1d sbp to 2d sbp
// For example: B -> (B, B), S0 -> (S0, S0)
std::vector<int32_t> id_1d_2_nd_;
  // The number of sbp dimensions (hierarchy size) used in the combination table
int32_t hierarchy_num_;
}; // class BoxingCollector
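// A hypothetical usage sketch (the sbp/blob/parallel-desc variables below are
// placeholders, not defined in this header):
//
//   BoxingCollector collector(/*max_axis=*/6);
//   std::vector<NdSbp> middle_sbps;
//   int32_t diag_node_pos = 0;
//   CHECK_JUST(collector.AskSbpCombination(
//       sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
//       consumer_parallel_desc, /*is_customized=*/false, middle_sbps, &diag_node_pos,
//       /*compute_cost=*/false));
//   // middle_sbps now holds the nd sbps to insert between the producer and consumer.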
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
class AutogradCapturedTensor final : public ProxyTensor<AutogradCapturedTensor> {
public:
static Maybe<AutogradCapturedTensor> MakeTensor(const std::shared_ptr<Tensor>& tensor) {
if (tensor->requires_grad()) {
CHECK_NOTNULL_OR_RETURN(tensor->grad_fn_node().get())
<< Error::RuntimeError()
<< "a grad function node is expected for the captured tensor "
"which requires_grad is True.";
}
std::shared_ptr<AutogradCapturedTensor> captured_tensor(
new AutogradCapturedTensor(JUST(tensor->detach())));
captured_tensor->set_autograd_meta(tensor->mut_autograd_meta());
captured_tensor->grad_fn_node_ = tensor->mut_grad_fn_node();
return captured_tensor;
}
std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_.lock(); }
void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {
PRINT_BUG_PROMPT_AND_ABORT();
}
std::shared_ptr<FunctionNode> mut_grad_fn_node() override { return grad_fn_node_.lock(); }
std::shared_ptr<Tensor> contiguous() const override {
const auto& tensor = std::const_pointer_cast<Tensor>(shared_from_this());
if (tensor_->is_contiguous()) { return tensor; }
return CHECK_JUST(functional::ToContiguous(tensor));
}
private:
explicit AutogradCapturedTensor(const std::shared_ptr<Tensor>& tensor)
: ProxyTensor<AutogradCapturedTensor>(tensor) {}
private:
std::weak_ptr<FunctionNode> grad_fn_node_;
};
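// NOTE: the captured tensor wraps a detached copy and holds its grad_fn node through a
// weak_ptr, which presumably avoids a shared_ptr ownership cycle between saved tensors
// and their FunctionNode during backward.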
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <stack>
#include <queue>
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
namespace {
void GatherFunctionNodes(FunctionNode* node, std::stack<std::shared_ptr<FunctionNode>>& stack) {
for (auto& prev_node : node->next_functions()) {
if (prev_node) {
if (prev_node.use_count() == 1) { stack.push(prev_node); }
}
}
}
/* NOTE:
* Stack overflows when releasing a very deep computation graph without
* a custom deleter.
*
* For example, here is a very deep computation graph:
* Tensor -> FunctionNode -> Tensor -> FunctionNode -> ... -> Tensor -> FunctionNode
* When releasing the first Tensor, it will trigger the recursive deletion and stack overflow.
*
* So we must set a custom deleter and release them iteratively.
*/
void FunctionNodeDeleter(FunctionNode* node) {
std::stack<std::shared_ptr<FunctionNode>> stack;
node->ReleaseData();
GatherFunctionNodes(node, stack);
delete node;
while (!stack.empty()) {
auto now_node = std::move(stack.top());
stack.pop();
now_node->ReleaseData();
GatherFunctionNodes(now_node.get(), stack);
}
}
bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_datas) {
return std::any_of(out_meta_datas.begin(), out_meta_datas.end(),
[](const std::shared_ptr<AutogradMeta>& meta_data) {
return !meta_data->current_grad()->Empty();
});
}
Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
autograd::AutoGradMode mode(autograd_mode);
auto current_grad = JUST(autograd_meta->current_grad()->GetAccTensor({}));
if (!current_grad) { return Maybe<void>::Ok(); }
if (autograd_meta->acc_grad()) {
    // Should not accumulate grad in place. For example,
// >>> z = x + y
// >>> p = x / z
// >>> p.sum().backward()
//
    // As we know, dx = dz + dp / z and dy = dz, so it would lead to a wrong value
// for dy if dx is shared with dz.
const auto& output = JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1,
/*inplace=*/autograd_meta->is_grad_acc_inplace()));
JUST(autograd_meta->set_acc_grad(output));
} else {
JUST(autograd_meta->set_acc_grad(current_grad));
}
for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {
auto new_grad = hook(autograd_meta->acc_grad());
if (new_grad) { JUST(autograd_meta->set_acc_grad(new_grad)); }
}
return Maybe<void>::Ok();
}
Maybe<void> RawTorchConsistentTensor(const std::shared_ptr<one::Tensor>& tensor) {
// Do nothing.
return Maybe<void>::Ok();
}
static constexpr auto* TorchConsistentTensor =
DECORATE(&RawTorchConsistentTensor, CheckConsistentTensorMeta);
Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) {
for (const auto& tensor : tensor_tuple) {
if (tensor->is_consistent()) { JUST(TorchConsistentTensor(tensor)); }
}
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
const TensorTuple& out_grads,
bool retain_graph,
bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check;
return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
}
Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(inputs));
JUST(CheckConsistentTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check;
return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
create_graph);
}
Maybe<void> FunctionNode::AccGrad4RetainGradTensor() {
for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {
if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); }
}
return Maybe<void>::Ok();
}
Maybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {
for (auto i = 0; i < output_meta_data_.size(); i++) {
auto& out = output_meta_data_[i];
if (out->is_leaf() && out->requires_grad()) {
JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false));
// control acc_grad to do boxing conditionally
const auto& acc_grad = out->acc_grad();
if (GlobalGradSyncMode::is_enabled() && acc_grad->is_consistent()) {
auto& tensor_info = output_tensor_infos_[i];
const auto& placement = JUST(tensor_info.placement());
const auto& nd_sbp = JUST(tensor_info.sbp());
JUST(out->set_acc_grad(
JUST(functional::ToConsistent(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
GetNoneSbpList(), /* check_meta */ false))));
}
}
}
return Maybe<void>::Ok();
}
void FunctionNode::ReleaseOutTensorArgs() {
for (const std::shared_ptr<AutogradMeta>& meta_data : output_meta_data_) {
meta_data->current_grad()->Release();
}
}
Maybe<bool> FunctionNode::Apply(bool create_graph) {
CHECK_NOTNULL_OR_RETURN(backward_fn_)
<< "This FunctionNode with name `" << name() << "` has been released.\n"
<< "Maybe you try to backward through the node a second time. Specify retain_graph=True when "
"calling .backward() or autograd.grad() the first time.";
if (!IsReadyToRun(output_meta_data_)) { return false; }
TensorTuple input_grads(input_meta_data_.size());
TensorTuple output_grads(output_meta_data_.size());
for (int i = 0; i < output_meta_data_.size(); ++i) {
if (output_meta_data_.at(i)->current_grad()->Empty()) {
output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros());
} else {
const auto& hooks = JUST(oneflow::VectorAt(output_meta_data_, i))->hooks();
JUST(oneflow::VectorAt(output_grads, i)) =
JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad()->GetAccTensor(hooks));
}
}
JUST(backward_fn_->body(output_grads, &input_grads, create_graph));
for (int i = 0; i < input_meta_data_.size(); ++i) {
if (JUST(VectorAt(input_grads, i))) {
CHECK_NOTNULL_OR_RETURN(input_meta_data_.at(i))
<< name_
<< " calculate grad for tensor which requires_grad is False. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible";
JUST(input_meta_data_.at(i)->current_grad()->PushPartialTensor(input_grads.at(i)));
}
}
return true;
}
void GraphFunctionNode::ReleaseData() {
if (backward_fn_ && backward_fn_->status()) { backward_fn_.reset(); }
}
/*static*/ std::shared_ptr<GraphFunctionNode> GraphFunctionNode::New(
const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, const TensorTuple& outputs) {
auto node = std::shared_ptr<GraphFunctionNode>(
new GraphFunctionNode(name, backward_fn, inputs, outputs), FunctionNodeDeleter);
return node;
}
GraphFunctionNode::GraphFunctionNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, const TensorTuple& outputs)
: FunctionNode(name, backward_fn) {
input_meta_data_.resize(inputs.size());
next_functions_.reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
if (inputs.at(i)->requires_grad()) {
input_meta_data_.at(i) = inputs.at(i)->mut_autograd_meta();
next_functions_.emplace_back(inputs.at(i)->mut_grad_fn_node());
}
}
output_meta_data_.resize(outputs.size());
output_tensor_infos_.reserve(outputs.size());
for (int i = 0; i < outputs.size(); ++i) {
const auto& autograd_meta =
NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf());
outputs.at(i)->set_autograd_meta(autograd_meta);
output_meta_data_.at(i) = outputs.at(i)->mut_autograd_meta();
output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
}
backward_fn_ = backward_fn;
}
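// Collects the grad_fn nodes of the output tensors as roots and initializes their dependency
// counts to zero.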
GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph)
: retain_graph_(retain_graph), create_graph_(create_graph) {
roots_.reserve(outputs.size());
for (const auto& out_tensor : outputs) {
FunctionNode* node = out_tensor->mut_grad_fn_node().get();
roots_.emplace_back(node);
dependencies_.insert(std::make_pair(node, 0));
}
}
// Computes the number of dependencies for each FunctionNode
Maybe<void> GraphTask::ComputeDependencies() {
HashSet<FunctionNode*> seen;
std::stack<FunctionNode*> stack;
for (FunctionNode* node : roots_) { stack.push(node); }
while (!stack.empty()) {
FunctionNode* node = stack.top();
stack.pop();
if (/*bool has_seen=*/!seen.insert(node).second) { continue; }
for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] += 1;
if (seen.find(next_node) == seen.end()) { stack.push(next_node); }
}
}
return Maybe<void>::Ok();
}
// Computes the number of dependencies for each FunctionNode and prunes useless FunctionNode
// according to input tensors
Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs) {
struct NodeFrame {
explicit NodeFrame(FunctionNode* node) : node_(node), next_function_idx_(0) {}
FunctionNode* node_;
size_t next_function_idx_;
FunctionNode* GetNextFunction() {
if (next_function_idx_ < node_->next_functions().size()) {
next_function_idx_ += 1;
return node_->next_functions().at(next_function_idx_ - 1).get();
} else {
return nullptr;
}
}
};
for (const auto& input : inputs) {
CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());
need_execute_.insert(input->mut_grad_fn_node().get());
}
HashSet<FunctionNode*> seen;
std::stack<NodeFrame> stack;
  // Note: DFS to determine whether each FunctionNode needs to execute or not.
for (const auto& root : roots_) { stack.push(NodeFrame(root)); }
while (!stack.empty()) {
NodeFrame& frame = stack.top();
if (/*bool has_seen=*/seen.find(frame.node_) != seen.end()) {
stack.pop();
continue;
}
if (FunctionNode* node = frame.GetNextFunction()) {
dependencies_[node] += 1;
if (seen.find(node) == seen.end()) {
stack.push(NodeFrame(node));
continue; // recurse
}
} else {
bool need_execute =
std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(),
[&](const std::shared_ptr<FunctionNode>& fn) {
return need_execute_.find(fn.get()) != need_execute_.end();
});
if (need_execute) { need_execute_.insert(frame.node_); }
seen.insert(frame.node_);
stack.pop();
}
}
return Maybe<void>::Ok();
}
Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
std::queue<FunctionNode*> queue;
for (FunctionNode* node : roots_) {
if (dependencies_[node] == 0) { queue.push(node); }
}
while (!queue.empty()) {
FunctionNode* node = queue.front();
queue.pop();
if (!need_execute_.empty() && need_execute_.find(node) == need_execute_.end()) {
node->ReleaseOutTensorArgs();
continue;
}
if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }
if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }
JUST(node->AccGrad4RetainGradTensor());
node->ReleaseOutTensorArgs();
if (!retain_graph_) { node->ReleaseData(); }
for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] -= 1;
if (dependencies_[next_node] == 0) { queue.push(next_node); }
}
}
return Maybe<void>::Ok();
}
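// Pushes out_grads into the outputs' current_grad buffers, then runs a GraphTask that
// accumulates gradients into leaf tensors.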
Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
const TensorTuple& out_grads,
bool retain_graph,
bool create_graph) {
for (int i = 0; i < outputs.size(); ++i) {
JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
}
GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependencies());
JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));
return Maybe<void>::Ok();
}
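// Like RunBackwardAndSaveGrads4LeafTensor, but prunes the graph to the given inputs,
// temporarily forces retain_grad on them, and returns their accumulated grads instead of
// saving grads for all leaves.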
Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) {
std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
GraphTask graph_task(outputs, retain_graph, create_graph);
std::vector<bool> ori_retain_grad(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
JUST(inputs.at(i)->set_retain_grad(true));
}
for (int i = 0; i < outputs.size(); ++i) {
JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
}
JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
  // Gets input grads and restores the original retain_grad flags
for (int i = 0; i < inputs.size(); ++i) {
input_current_grad->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
JUST(inputs.at(i)->set_acc_grad(nullptr));
JUST(inputs.at(i)->set_retain_grad(false));
}
}
return input_current_grad;
}
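// Creates a GraphFunctionNode for one forward op, after making sure that leaf inputs requiring
// grad have an accumulate node, and installs it as the grad_fn_node of every output tensor.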
Maybe<FunctionNode> GraphAutogradEngine::AddNode(
const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, TensorTuple* outputs) {
  // First, attach an accumulate FunctionNode to every leaf input tensor that requires grad but
  // has no grad_fn_node yet
for (const std::shared_ptr<Tensor>& in_tensor : inputs) {
if (in_tensor->is_leaf() && in_tensor->requires_grad()) {
if (!in_tensor->grad_fn_node()) { JUST(AddAccumulateFunctionNode(in_tensor)); }
}
}
std::shared_ptr<FunctionNode> func_node =
GraphFunctionNode::New(name, backward_fn, inputs, *outputs);
for (const std::shared_ptr<Tensor>& out_tensor : *outputs) {
out_tensor->set_grad_fn_node(func_node);
}
return func_node;
}
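// Each thread owns its own GraphAutogradEngine instance.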
AutogradEngine* GetThreadLocalAutogradEngine() {
thread_local static GraphAutogradEngine autograd_engine;
return &autograd_engine;
}
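// Attaches a terminal "accumulate_grad" FunctionNode with an empty backward body to a leaf
// tensor that has no grad_fn_node yet.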
Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {
auto backward_fn = std::make_shared<BackwardFunction>();
backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,
bool create_graph) -> Maybe<void> { return Maybe<void>::Ok(); };
backward_fn->status = []() { return false; };
tensor->set_grad_fn_node(GraphFunctionNode::New(
"accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor}));
return Maybe<void>::Ok();
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#include <list>
#include <vector>
#include <memory>
#include <functional>
#include "oneflow/core/common/util.h"
#include "oneflow/core/autograd/autograd_meta.h"
namespace oneflow {
namespace one {
class Tensor;
class TensorTuple;
using CaptureStatus = bool;
struct BackwardFunction {
std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)> body;
std::function<CaptureStatus()> status;
};
// Calculates one backward op
class FunctionNode {
public:
virtual ~FunctionNode() = default;
Maybe<bool> Apply(bool create_graph);
Maybe<void> AccGrad4LeafTensor(bool create_graph);
Maybe<void> AccGrad4RetainGradTensor();
void ReleaseOutTensorArgs();
  // Releases the underlying C++ std::function for backward when retain_graph=False, so that
  // `Apply` cannot be called a second time
virtual void ReleaseData() = 0;
const std::vector<std::shared_ptr<FunctionNode>>& next_functions() const {
return next_functions_;
}
const std::string& name() const { return name_; }
protected:
explicit FunctionNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn)
: name_(name), backward_fn_(backward_fn) {}
const std::string name_;
std::vector<std::shared_ptr<FunctionNode>> next_functions_;
std::vector<std::shared_ptr<AutogradMeta>> input_meta_data_;
std::vector<std::shared_ptr<AutogradMeta>> output_meta_data_;
std::vector<TensorInfo> output_tensor_infos_;
  // The actual backward function, built in `AutogradInterpreter`, that calculates one backward op
std::shared_ptr<BackwardFunction> backward_fn_;
};
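// Interface of the autograd engine: builds FunctionNodes for forward ops and drives backward
// execution.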
class AutogradEngine {
public:
virtual ~AutogradEngine() = default;
Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
const TensorTuple& out_grads, bool retain_graph,
bool create_graph);
Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
const TensorTuple& inputs,
const TensorTuple& out_grads,
bool retain_graph, bool create_graph);
virtual void ClearEngine() = 0;
  // Builds a FunctionNode, binds it to all output tensors, and saves it in the AutogradEngine
virtual Maybe<FunctionNode> AddNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, TensorTuple* outputs) = 0;
protected:
AutogradEngine() = default;
private:
virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
const TensorTuple& out_grads,
bool retain_graph, bool create_graph) = 0;
virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
const TensorTuple& inputs,
const TensorTuple& out_grads,
bool retain_graph,
bool create_graph) = 0;
};
// Graph Autograd Node and Engine
class GraphFunctionNode final : public FunctionNode {
public:
OF_DISALLOW_COPY_AND_MOVE(GraphFunctionNode);
static std::shared_ptr<GraphFunctionNode> New(
const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, const TensorTuple& outputs);
GraphFunctionNode() = delete;
~GraphFunctionNode() override = default;
void ReleaseData() override;
private:
GraphFunctionNode(const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, const TensorTuple& outputs);
};
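// Holds the state of one backward run: the root nodes, per-node dependency counts, and the set
// of nodes that actually need to execute.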
class GraphTask final {
public:
OF_DISALLOW_COPY_AND_MOVE(GraphTask);
GraphTask() = delete;
GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph);
Maybe<void> ComputeDependencies();
Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
Maybe<void> Apply(bool save_grad_for_leaf);
private:
bool retain_graph_;
bool create_graph_;
std::vector<FunctionNode*> roots_;
HashMap<FunctionNode*, int> dependencies_;
HashSet<FunctionNode*> need_execute_;
};
class GraphAutogradEngine final : public AutogradEngine {
public:
OF_DISALLOW_COPY_AND_MOVE(GraphAutogradEngine);
GraphAutogradEngine() = default;
~GraphAutogradEngine() override = default;
  void ClearEngine() override {}
Maybe<FunctionNode> AddNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, TensorTuple* outputs) override;
private:
Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
const TensorTuple& out_grads, bool retain_graph,
bool create_graph) override;
Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
const TensorTuple& inputs,
const TensorTuple& out_grads,
bool retain_graph,
bool create_graph) override;
};
AutogradEngine* GetThreadLocalAutogradEngine();
Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor);
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/autograd/autograd_function.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
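// Dispatches a FunctionOpExpr built from the given forward/backward functions, then clears
// requires_grad on every output that the op state marks as non-differentiable.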
/*static*/ Maybe<TensorTuple> AutogradFunctionBase::Apply(const std::string& name,
const FType& forward_fn,
const FType& backward_fn,
const TensorTuple& inputs) {
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>();
const auto& op = JUST(FunctionOpExpr::New(name, forward_fn, backward_fn));
JUST(OpInterpUtil::Dispatch(*op, inputs, outputs.get(), {}));
const HashSet<Tensor*>& non_differentiable_tensors = op->state()->NonDifferentiableTensors();
for (const auto& tensor : *outputs) {
if (non_differentiable_tensors.find(tensor.get()) != non_differentiable_tensors.end()) {
JUST(tensor->set_requires_grad(false));
}
}
return outputs;
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace one {
class TensorTuple;
class FunctionAutoGradCaptureState;
class FunctionOpExpr;
class AutogradFunctionBase {
public:
using FType = std::function<std::shared_ptr<TensorTuple>(
const std::shared_ptr<FunctionAutoGradCaptureState>&, const TensorTuple&)>;
AutogradFunctionBase() = default;
virtual ~AutogradFunctionBase() = default;
static Maybe<TensorTuple> Apply(const std::string& name, const FType& forward_fn,
const FType& backward_fn, const TensorTuple& inputs);
};
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
if (TRY(tensor.nd_sbp()).IsOk()) { nd_sbp_ = CHECK_JUST(tensor.nd_sbp()); }
}
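// Caches, per thread, the flattened sbp tuple of each NdSbp symbol.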
Maybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) {
static thread_local HashMap<Symbol<NdSbp>, std::vector<Symbol<SbpParallel>>> map;
auto iter = map.find(nd_sbp);
if (iter == map.end()) {
std::vector<Symbol<SbpParallel>> sbp_tuple;
sbp_tuple.reserve(nd_sbp->sbp_parallel().size());
for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) {
sbp_tuple.push_back(SymbolOf(sbp_parallel));
}
iter = map.emplace(nd_sbp, sbp_tuple).first;
}
return iter->second;
}
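// Creates a zero-filled tensor matching the recorded shape and dtype: local when a device was
// captured, otherwise consistent with the captured placement and sbp.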
Maybe<Tensor> TensorInfo::zeros() const {
if (device_.has_value()) {
const auto& device = JUST(device_);
return functional::Constant(*shape_.get(), 0, dtype_, device);
} else {
const auto& parallel_desc = JUST(parallel_desc_);
const auto& nd_sbp = JUST(nd_sbp_);
const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp));
return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
}
}
AutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf)
: is_leaf_(is_leaf),
requires_grad_(requires_grad),
retain_grad_(false),
is_grad_acc_inplace_(false),
current_grad_(new TensorArg) {}
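// A StaticZerosTensor is materialized into a mirrored tensor before being stored as acc_grad.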
Maybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) {
if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) {
acc_grad_ = JUST(static_zeros_tensor->AsMirroredTensor());
} else {
acc_grad_ = grad;
}
return Maybe<void>::Ok();
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
#include <memory>
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
class Shape;
class Device;
class ParallelDesc;
class NdSbp;
namespace one {
class Tensor;
class TensorArg;
class MirroredTensor;
class AutogradMeta final {
public:
AutogradMeta() = delete;
AutogradMeta(bool requires_grad, bool is_leaf);
// Getters
const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }
bool is_grad_acc_inplace() const { return is_grad_acc_inplace_; }
bool requires_grad() const { return requires_grad_; }
bool is_leaf() const { return is_leaf_; }
bool retain_grad() const { return retain_grad_; }
using Hook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<const Tensor>&)>;
const std::vector<Hook>& hooks() const { return hooks_; }
const std::vector<Hook>& post_grad_accumulation_hooks() const {
return post_grad_accumulation_hooks_;
}
// Setters
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
void set_is_grad_acc_inplace(bool is_inplace) { is_grad_acc_inplace_ = is_inplace; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
void add_hook(const Hook& hook) { hooks_.emplace_back(hook); }
void add_post_grad_accumulation_hook(const Hook& hook) {
post_grad_accumulation_hooks_.emplace_back(hook);
}
private:
bool is_leaf_;
// Only meaningful on leaf Tensors (must be false otherwise)
bool requires_grad_;
  // Only meaningful on non-leaf Tensors (must be false otherwise)
bool retain_grad_;
// Control whether grad accumulation is inplace. Don't change it
// unless you know what you are doing
bool is_grad_acc_inplace_;
std::shared_ptr<Tensor> acc_grad_;
std::shared_ptr<TensorArg> current_grad_;
std::vector<Hook> hooks_;
std::vector<Hook> post_grad_accumulation_hooks_;
};
inline std::shared_ptr<AutogradMeta> NewAutogradMeta(bool requires_grad, bool is_leaf) {
return std::shared_ptr<AutogradMeta>(new AutogradMeta(requires_grad, is_leaf));
}
class TensorInfo final {
public:
TensorInfo() = delete;
explicit TensorInfo(const Tensor& tensor);
Maybe<Tensor> zeros() const;
Optional<Symbol<ParallelDesc>> placement() const { return parallel_desc_; }
Optional<Symbol<NdSbp>> sbp() const { return nd_sbp_; }
private:
std::shared_ptr<const Shape> shape_;
Symbol<DType> dtype_;
Optional<Symbol<Device>> device_; // for local tensor
Optional<Symbol<ParallelDesc>> parallel_desc_; // for consistent tensor
Optional<Symbol<NdSbp>> nd_sbp_; // for consistent tensor
};
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_