Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/foreign_lock_helper.h"
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/singleton.h"
namespace py = pybind11;
namespace oneflow {
class GILForeignLockHelper final : public ForeignLockHelper {
Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>& Callback) const override {
if (PyGILState_Check()) {
py::gil_scoped_release release;
JUST(Callback());
} else {
JUST(Callback());
}
return Maybe<void>::Ok();
}
Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>& Callback) const override {
if (!PyGILState_Check()) {
py::gil_scoped_acquire acquire;
JUST(Callback());
} else {
JUST(Callback());
}
return Maybe<void>::Ok();
}
};
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("RegisterGILForeignLockHelper", []() {
Singleton<ForeignLockHelper>::Delete();
Singleton<ForeignLockHelper>::SetAllocated(new GILForeignLockHelper());
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include <unordered_map>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/cluster_instruction.h"
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(std::vector<int64_t>);
PYBIND11_MAKE_OPAQUE(std::unordered_map<int64_t, std::shared_ptr<std::vector<int64_t>>>);
namespace oneflow {
namespace {
using IntList = std::vector<int64_t>;
using Int2IntListMap = std::unordered_map<int64_t, std::shared_ptr<IntList>>;
bool Int2IntListMapContaining(const Int2IntListMap& bigger, const Int2IntListMap& smaller) {
for (const auto& pair : smaller) {
if (bigger.find(pair.first) == bigger.end()) { return false; }
const auto& bigger_device_ids = bigger.find(pair.first)->second;
std::vector<int64_t>::iterator ret;
for (int64_t device_id : *pair.second) {
ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id);
if (ret == bigger_device_ids->end()) { return false; }
}
}
return true;
}
} // namespace
PYBIND11_MODULE(_oneflow_internal, m) {
using IntList = std::vector<int64_t>;
using Int2IntListMap = std::unordered_map<int64_t, std::shared_ptr<IntList>>;
py::module_ oneflow_api_util = m.def_submodule("util");
py::class_<IntList, std::shared_ptr<IntList>>(oneflow_api_util, "IntList")
.def(py::init<>())
.def("__len__", [](const std::shared_ptr<IntList>& v) { return v->size(); })
.def(
"items",
[](std::shared_ptr<IntList>& v) { return py::make_iterator(v->begin(), v->end()); },
py::keep_alive<0, 1>())
.def("__getitem__", (IntList::reference & (IntList::*)(IntList::size_type pos)) & IntList::at)
.def(
"__iter__",
[](std::shared_ptr<IntList>& v) { return py::make_iterator(v->begin(), v->end()); },
py::keep_alive<0, 1>())
.def("__eq__", [](std::shared_ptr<IntList>& lhs, std::shared_ptr<IntList>& rhs) {
return *lhs == *rhs;
});
py::class_<Int2IntListMap, std::shared_ptr<Int2IntListMap>>(oneflow_api_util, "Int2IntListMap")
.def(py::init<>())
.def("__len__", [](const std::shared_ptr<Int2IntListMap>& v) { return v->size(); })
.def(
"items",
[](std::shared_ptr<Int2IntListMap>& v) {
return py::make_iterator(v->begin(), v->end());
},
py::keep_alive<0, 1>())
.def("__getitem__",
(Int2IntListMap::mapped_type & (Int2IntListMap::*)(const Int2IntListMap::key_type& pos))
& Int2IntListMap::operator[])
.def(
"__iter__",
[](std::shared_ptr<Int2IntListMap>& v) {
return py::make_iterator(v->begin(), v->end());
},
py::keep_alive<0, 1>())
.def("__eq__",
[](std::shared_ptr<Int2IntListMap>& lhs, std::shared_ptr<Int2IntListMap>& rhs) {
return Int2IntListMapContaining(*lhs, *rhs) && Int2IntListMapContaining(*rhs, *lhs);
});
::oneflow::OneflowModuleRegistry().ImportAll(m);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_MLIR
#include "oneflow/ir/include/OneFlow/Extension.h"
#include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h"
#include <glog/logging.h>
#include "oneflow/api/python/of_api_registry.h"
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("ir", m) {
m.def("load_jit_shared_lib",
[](const std::string& lib_path) { MutSharedLibPaths()->insert(lib_path); });
}
} // namespace oneflow
#endif // WITH_MLIR
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/job_build/job_build_and_infer.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("JobBuildAndInferCtx_Open", &JobBuildAndInferCtx_Open);
m.def("JobBuildAndInferCtx_GetCurrentJobName", &JobBuildAndInferCtx_GetCurrentJobName);
m.def("JobBuildAndInferCtx_GetCurrentJobId", &JobBuildAndInferCtx_GetCurrentJobId);
m.def("JobBuildAndInferCtx_Close", &JobBuildAndInferCtx_Close);
m.def("CurJobBuildAndInferCtx_CheckJob", &CurJobBuildAndInferCtx_CheckJob);
m.def("CurJobBuildAndInferCtx_SetJobConf", &CurJobBuildAndInferCtx_SetJobConf);
m.def("CurJobBuildAndInferCtx_SetTrainConf", &CurJobBuildAndInferCtx_SetTrainConf);
m.def("CurJobBuildAndInferCtx_Complete", &CurJobBuildAndInferCtx_Complete,
py::call_guard<py::gil_scoped_release>());
m.def("CurJobBuildAndInferCtx_Rebuild", &CurJobBuildAndInferCtx_Rebuild,
py::call_guard<py::gil_scoped_release>());
m.def("CurJobBuildAndInferCtx_HasJobConf", &CurJobBuildAndInferCtx_HasJobConf);
m.def("CurJobBuildAndInferCtx_AddAndInferMirroredOp",
&CurJobBuildAndInferCtx_AddAndInferMirroredOp, py::call_guard<py::gil_scoped_release>());
m.def("CurJobBuildAndInferCtx_AddAndInferConsistentOp",
&CurJobBuildAndInferCtx_AddAndInferConsistentOp);
m.def("CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair",
&CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair);
m.def("JobBuildAndInferCtx_GetSerializedIdListAsStaticShape",
&JobBuildAndInferCtx_GetSerializedIdListAsStaticShape);
m.def("JobBuildAndInferCtx_GetDataType", &JobBuildAndInferCtx_GetDataType);
m.def("JobBuildAndInferCtx_IsDynamic", &JobBuildAndInferCtx_IsDynamic);
m.def("JobBuildAndInferCtx_IsDisableBoxing", &JobBuildAndInferCtx_IsDisableBoxing);
m.def("JobBuildAndInferCtx_GetSplitAxisFromProducerView",
&JobBuildAndInferCtx_GetSplitAxisFromProducerView);
m.def("JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView",
&JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView);
m.def("CurJobBuildAndInferCtx_AddLossLogicalBlobName",
&CurJobBuildAndInferCtx_AddLossLogicalBlobName);
m.def("JobBuildAndInferCtx_IsMirroredBlob", &JobBuildAndInferCtx_IsMirroredBlob);
m.def("JobBuildAndInferCtx_MirroredBlobGetNumSubLbi",
&JobBuildAndInferCtx_MirroredBlobGetNumSubLbi);
m.def("JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi",
&JobBuildAndInferCtx_MirroredBlobGetSubLbi);
m.def("JobBuildAndInferCtx_CheckLbnValidAndExist", &JobBuildAndInferCtx_CheckLbnValidAndExist);
m.def("JobBuildAndInferCtx_GetOpBlobLbn", &JobBuildAndInferCtx_GetOpBlobLbn);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_
#define ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_name_scope.h"
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/record/record.pb.h"
namespace oneflow {
inline Maybe<void> JobBuildAndInferCtx_Open(const std::string& job_name) {
auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());
return mgr->OpenJobBuildAndInferCtx(job_name);
}
inline Maybe<std::string> JobBuildAndInferCtx_GetCurrentJobName() {
auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());
return mgr->GetCurrentJobName();
}
inline Maybe<int64_t> JobBuildAndInferCtx_GetCurrentJobId() {
return JUST(GetCurInferCtx())->job_id();
}
inline Maybe<void> JobBuildAndInferCtx_Close() {
auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());
JUST(mgr->CloseCurrentJobBuildAndInferCtx());
return Maybe<void>::Ok();
}
inline Maybe<void> CurJobBuildAndInferCtx_CheckJob() { return JUST(GetCurInferCtx())->CheckJob(); }
inline Maybe<void> CurJobBuildAndInferCtx_SetJobConf(const std::string& job_conf_str) {
JobConfigProto job_conf;
CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed";
return JUST(GetCurInferCtx())->SetJobConf(job_conf);
}
inline Maybe<void> CurJobBuildAndInferCtx_SetTrainConf(const std::string& train_conf_str) {
TrainConf train_conf;
CHECK_OR_RETURN(TxtString2PbMessage(train_conf_str, &train_conf)) << "train conf parse failed";
return JUST(GetCurInferCtx())->SetTrainConf(train_conf);
}
inline Maybe<void> CurJobBuildAndInferCtx_Complete() { return JUST(GetCurInferCtx())->Complete(); }
inline Maybe<void> CurJobBuildAndInferCtx_Rebuild() { return JUST(GetCurInferCtx())->Rebuild(); }
inline Maybe<bool> CurJobBuildAndInferCtx_HasJobConf() {
return JUST(GetCurInferCtx())->HasJobConf();
}
inline Maybe<std::string> CurJobBuildAndInferCtx_AddAndInferMirroredOp(
const std::string& op_conf_str) {
OperatorConf op_conf;
CHECK_OR_RETURN(TxtString2PbMessage(op_conf_str, &op_conf)) << "operator conf parse failed";
auto* ctx = JUST(GetCurInferCtx());
const auto& op_attribute = JUST(ctx->AddAndInferMirroredOp(op_conf));
return PbMessage2TxtString(*op_attribute);
}
inline Maybe<std::string> CurJobBuildAndInferCtx_AddAndInferConsistentOp(
const std::string& op_conf_str) {
OperatorConf op_conf;
CHECK_OR_RETURN(TxtString2PbMessage(op_conf_str, &op_conf)) << "operator conf parse failed";
auto* ctx = JUST(GetCurInferCtx());
const auto& op_attribute = JUST(ctx->AddAndInferConsistentOp(op_conf));
return PbMessage2TxtString(*op_attribute);
}
inline Maybe<void> CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(
const std::string& lbi_uuid_pair_str) {
auto* ctx = JUST(GetCurInferCtx());
LbiAndDiffWatcherUuidPair lbi_uuid_pair;
CHECK_OR_RETURN(TxtString2PbMessage(lbi_uuid_pair_str, &lbi_uuid_pair))
<< "LbiAndDiffWatcherUuidPair parse failed";
return ctx->AddLbiAndDiffWatcherUuidPair(lbi_uuid_pair);
}
inline Maybe<std::string> JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(
const std::string& job_name, const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
const auto& shape = JUST(ctx->GetStaticShape(lbn));
Int64List id_list;
*id_list.mutable_value() = {shape->dim_vec().begin(), shape->dim_vec().end()};
return PbMessage2TxtString(id_list);
}
inline Maybe<long long> JobBuildAndInferCtx_GetDataType(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return JUST(ctx->GetDataType(lbn));
}
inline Maybe<bool> JobBuildAndInferCtx_IsDynamic(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return ctx->IsDynamic(lbn);
}
inline Maybe<bool> JobBuildAndInferCtx_IsDisableBoxing(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return ctx->IsDisableBoxing(lbn);
}
inline Maybe<std::string> JobBuildAndInferCtx_GetSplitAxisFromProducerView(
const std::string& job_name, const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return PbMessage2TxtString(*JUST(ctx->GetSplitAxisFromProducerView(lbn)));
}
inline Maybe<std::string> JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(
const std::string& job_name, const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return PbMessage2TxtString(JUST(ctx->GetParallelDescFromProducerView(lbn))->parallel_conf());
}
inline Maybe<void> CurJobBuildAndInferCtx_AddLossLogicalBlobName(const std::string& lbn) {
return JUST(GetCurInferCtx())->AddLossLogicalBlobName(lbn);
}
inline Maybe<bool> JobBuildAndInferCtx_IsMirroredBlob(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return ctx->IsMirroredBlob(lbn);
}
inline Maybe<int> JobBuildAndInferCtx_MirroredBlobGetNumSubLbi(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return ctx->MirroredBlobGetNumSubLbi(lbn);
}
inline Maybe<std::string> JobBuildAndInferCtx_MirroredBlobGetSubLbi(const std::string& job_name,
const std::string& lbn,
int index) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
return PbMessage2TxtString(*JUST(ctx->MirroredBlobGetSubLbi(lbn, index)));
}
inline Maybe<void> JobBuildAndInferCtx_CheckLbnValidAndExist(const std::string& job_name,
const std::string& lbn) {
auto* ctx = JUST(GetJobBuildAndInferCtx(job_name));
JUST(ctx->CheckLbnValidAndExist(lbn));
return Maybe<void>::Ok();
}
inline Maybe<std::string> JobBuildAndInferCtx_GetOpBlobLbn(const std::string& job_name,
const std::string& op_name,
const std::string bn_in_op) {
const auto* job_ctx = JUST(GetJobBuildAndInferCtx(job_name));
return job_ctx->GetOpBlobLbn(op_name, bn_in_op);
}
inline Maybe<void> AddTensorAsGraphLoss(const std::shared_ptr<one::Tensor>& t) {
CHECK_OR_RETURN(t->is_lazy());
CHECK_OR_RETURN(LazyMode::is_enabled());
const std::string& loss_lbn = one::TensorNameScope::Global()->Lookup(t);
CHECK_OR_RETURN("" != loss_lbn);
return JUST(GetCurInferCtx())->AddLossLogicalBlobName(loss_lbn);
}
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/lazy_mode.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("lazy_mode", m) {
py::class_<LazyMode::Guard, std::shared_ptr<LazyMode::Guard>>(m, "guard")
.def(py::init(
[](const bool is_enabled) { return std::make_shared<LazyMode::Guard>(is_enabled); }))
.def("__enter__", [](const LazyMode::Guard& guard_obj) {})
.def("__exit__", [](const LazyMode::Guard& guard_obj, const py::object& type,
const py::object& value, const py::object& traceback) {});
m.def("is_enabled", []() { return LazyMode::is_enabled(); });
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/multiprocessing/object_ptr.h"
#include "oneflow/core/ep/cpu/cpu_device_manager.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/ep/cpu/cpu_device.h"
#include <csignal>
#include <stdexcept>
#if defined(__linux__)
#include <sys/prctl.h>
#include <system_error>
#endif
#define SYSASSERT(rv, ...) \
if ((rv) < 0) { throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); }
namespace oneflow {
namespace multiprocessing {
namespace py = pybind11;
void multiprocessing_init() {
auto multiprocessing_module = OFObjectPtr(PyImport_ImportModule("oneflow.multiprocessing"));
if (!multiprocessing_module) {
throw std::runtime_error("multiprocessing init error >> multiprocessing_module init fail!");
}
auto module = py::handle(multiprocessing_module).cast<py::module>();
module.def("_prctl_pr_set_pdeathsig", [](int signal) {
#if defined(__linux__)
auto rv = prctl(PR_SET_PDEATHSIG, signal);
SYSASSERT(rv, "prctl");
#endif
});
// Py_RETURN_TRUE;
}
void set_num_threads(int num) {
int64_t cpu_logic_core = std::thread::hardware_concurrency();
if (num <= 0) {
py::print("Warning : ", num, " less than 1 will be set to 1.");
num = 1;
} else if (num >= cpu_logic_core) {
py::print("Warning : ", num,
" is greater than the number of logical cores and will be set to the maximum number "
"of logical cores ",
cpu_logic_core);
num = cpu_logic_core;
}
auto cpu_device = std::static_pointer_cast<ep::CpuDevice>(
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCPU, 0));
cpu_device->SetNumThreads(num);
}
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::options options;
options.disable_function_signatures();
m.def("_multiprocessing_init", &multiprocessing_init);
m.def("_set_num_threads", &set_num_threads);
options.disable_function_signatures();
}
} // namespace multiprocessing
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/multiprocessing/object_ptr.h"
template<>
void OFPointer<PyObject>::free() {
if (ptr) Py_DECREF(ptr);
}
template class OFPointer<PyObject>;
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
// reference: pytorch/torch/csrc/utils/object_ptr.h
// https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/utils/object_ptr.h
template<class T>
class OFPointer {
public:
OFPointer() : ptr(nullptr){};
explicit OFPointer(T* ptr) noexcept : ptr(ptr){};
OFPointer(OFPointer&& p) noexcept {
free();
ptr = p.ptr;
p.ptr = nullptr;
};
~OFPointer() { free(); };
T* get() { return ptr; }
const T* get() const { return ptr; }
T* release() {
T* tmp = ptr;
ptr = nullptr;
return tmp;
}
operator T*() { return ptr; }
OFPointer& operator=(T* new_ptr) noexcept {
free();
ptr = new_ptr;
return *this;
}
OFPointer& operator=(OFPointer&& p) noexcept {
free();
ptr = p.ptr;
p.ptr = nullptr;
return *this;
}
T* operator->() { return ptr; }
explicit operator bool() const { return ptr != nullptr; }
private:
void free();
T* ptr = nullptr;
};
/**
* An RAII-style, owning pointer to a PyObject. You must protect
* destruction of this object with the GIL.
*
* WARNING: Think twice before putting this as a field in a C++
* struct. This class does NOT take out the GIL on destruction,
* so if you will need to ensure that the destructor of your struct
* is either (a) always invoked when the GIL is taken or (b) takes
* out the GIL itself. Easiest way to avoid this problem is to
* not use THPPointer in this situation.
*/
using OFObjectPtr = OFPointer<PyObject>;
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/ipc/shared_memory.h"
namespace oneflow {
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("multiprocessing", m) {
py::class_<ipc::SharedMemory, std::shared_ptr<ipc::SharedMemory>>(m, "SharedMemory")
.def(py::init([](const std::string& name, bool create, size_t size) {
if (create) { return ipc::SharedMemory::Open(size, create).GetPtrOrThrow(); }
return ipc::SharedMemory::Open(name, create).GetPtrOrThrow();
}),
py::arg("name") = "", py::arg("create") = false, py::arg("size") = 0)
.def("close", &ipc::SharedMemory::Close)
.def("unlink", &ipc::SharedMemory::Unlink)
.def_property_readonly("buf",
[](ipc::SharedMemory* shm) {
return py::memoryview::from_memory(shm->mut_buf(), shm->size());
})
.def_property_readonly("name", &ipc::SharedMemory::name)
.def_property_readonly("size", &ipc::SharedMemory::size);
m.def("unlink_all_shared_memory",
[]() { return ipc::SharedMemoryManager::get().UnlinkAllShms(); });
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/numpy.h"
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitNumpyCAPI", []() { return oneflow::numpy::InitNumpyCAPI(); });
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/of_api_registry.h"
namespace oneflow {
namespace {
// If different APIs are registered under the same path, the BuildModuleFuntion of which will be
// saved in the corresponding vector.
using SubModuleMap = std::map<std::string, std::vector<std::function<void(pybind11::module&)>>>;
SubModuleMap* GetSubModuleMap() {
static SubModuleMap sub_module_map;
return &sub_module_map;
}
} // namespace
void OneflowModuleRegistry::Register(std::string module_path,
std::function<void(pybind11::module&)> BuildModule) {
(*GetSubModuleMap())[module_path].emplace_back(BuildModule);
}
void OneflowModuleRegistry::ImportAll(pybind11::module& m) {
for (const auto& pair : (*GetSubModuleMap())) {
for (const auto& BuildModule : pair.second) { BuildSubModule(pair.first, m, BuildModule); }
}
}
void OneflowModuleRegistry::BuildSubModule(
const std::string& module_path, pybind11::module& m,
const std::function<void(pybind11::module&)>& BuildModule) {
if (module_path.empty()) {
BuildModule(m);
return;
}
size_t dot_pos = module_path.find(".");
if (dot_pos == std::string::npos) {
pybind11::module sub_module = m.def_submodule(module_path.data());
BuildModule(sub_module);
} else {
const std::string& sub_module_name = module_path.substr(0, dot_pos);
pybind11::module sub_module = m.def_submodule(sub_module_name.data());
BuildSubModule(module_path.substr(dot_pos + 1), sub_module, BuildModule);
}
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_
#define ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_
#include <pybind11/pybind11.h>
#include <map>
#include <vector>
#include <functional>
#include "oneflow/api/python/caster/maybe.h"
#include "oneflow/api/python/caster/tensor.h"
#include "oneflow/api/python/caster/optional.h"
#include "oneflow/core/common/preprocessor.h"
namespace oneflow {
class OneflowModuleRegistry {
public:
OneflowModuleRegistry() = default;
~OneflowModuleRegistry() = default;
void Register(std::string module_path, std::function<void(pybind11::module&)> BuildModule);
void ImportAll(pybind11::module& m);
private:
void BuildSubModule(const std::string& module_path, pybind11::module& m,
const std::function<void(pybind11::module&)>& BuildModule);
};
} // namespace oneflow
#define ONEFLOW_API_PYBIND11_MODULE(module_path, m) \
static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module&); \
namespace { \
struct OfApiRegistryInit { \
OfApiRegistryInit() { \
::oneflow::OneflowModuleRegistry().Register(module_path, \
&OF_PP_CAT(OneflowApiPythonModule, __LINE__)); \
} \
}; \
OfApiRegistryInit of_api_registry_init; \
} \
static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module & m)
#endif // ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/api/python/ofblob/ofblob.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("Ofblob_GetDataType", &Ofblob_GetDataType);
m.def("OfBlob_NumAxes", &OfBlob_NumAxes);
m.def("OfBlob_IsDynamic", &OfBlob_IsDynamic);
m.def("OfBlob_CopyShapeTo", &OfBlob_CopyShapeTo);
m.def("OfBlob_CopyStaticShapeTo", &OfBlob_CopyStaticShapeTo);
m.def("OfBlob_CopyShapeFrom", &OfBlob_CopyShapeFrom);
m.def("Dtype_GetOfBlobCopyToBufferFuncName", &Dtype_GetOfBlobCopyToBufferFuncName);
m.def("Dtype_GetOfBlobCopyFromBufferFuncName", &Dtype_GetOfBlobCopyFromBufferFuncName);
#define EXPORT_COPY_DATA_API(T, type_proto) \
m.def("OfBlob_CopyToBuffer_" OF_PP_STRINGIZE(T), \
[](uint64_t of_blob_ptr, py::array_t<T> array) { \
oneflow::NumPyArrayPtr array_ptr(array.ptr()); \
OfBlob_CopyToBuffer_##T(of_blob_ptr, array_ptr); \
}); \
m.def("OfBlob_CopyFromBuffer_" OF_PP_STRINGIZE(T), \
[](uint64_t of_blob_ptr, py::array_t<T> array) { \
oneflow::NumPyArrayPtr array_ptr(array.ptr()); \
OfBlob_CopyFromBuffer_##T(of_blob_ptr, array_ptr); \
});
OF_PP_FOR_EACH_TUPLE(EXPORT_COPY_DATA_API, POD_DATA_TYPE_SEQ);
#undef EXPORT_COPY_DATA_API
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_E_H_
#define ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_E_H_
#include "oneflow/core/common/foreign_lock_helper.h"
#include "oneflow/core/common/type_traits.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/api/common/ofblob.h"
#include "oneflow/extension/python/numpy.h"
namespace py = pybind11;
namespace oneflow {
template<typename T>
struct BlobNumpyCopyUtil {
static Maybe<void> From(uint64_t of_blob_ptr, const NumPyArrayPtr& array) {
return BlobBufferCopyUtil<T>::From(of_blob_ptr, (T*)array.data(), array.size());
}
static Maybe<void> To(uint64_t of_blob_ptr, const NumPyArrayPtr& array) {
return BlobBufferCopyUtil<T>::To(of_blob_ptr, (T*)array.data(), array.size());
}
};
} // namespace oneflow
#define DEFINE_COPIER(T, type_proto) \
inline void OfBlob_CopyToBuffer_##T(uint64_t of_blob_ptr, const oneflow::NumPyArrayPtr& array) { \
oneflow::BlobNumpyCopyUtil<T>::To(of_blob_ptr, array).GetOrThrow(); \
} \
inline void OfBlob_CopyFromBuffer_##T(uint64_t of_blob_ptr, \
const oneflow::NumPyArrayPtr& array) { \
oneflow::BlobNumpyCopyUtil<T>::From(of_blob_ptr, array).GetOrThrow(); \
}
OF_PP_FOR_EACH_TUPLE(DEFINE_COPIER, POD_DATA_TYPE_SEQ);
#undef DEFINE_COPIER
inline std::string Dtype_GetOfBlobCopyToBufferFuncName(int64_t dtype) {
using namespace oneflow;
static const HashMap<int64_t, std::string> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, "OfBlob_CopyToBuffer_" #type_cpp},
OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return data_type2func_name.at(dtype);
}
inline std::string Dtype_GetOfBlobCopyFromBufferFuncName(int64_t dtype) {
using namespace oneflow;
static const HashMap<int64_t, std::string> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
{type_proto, "OfBlob_CopyFromBuffer_" #type_cpp},
OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
};
return data_type2func_name.at(dtype);
}
#endif // ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_E_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_H_
#define ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_H_
#include "oneflow/core/common/type_traits.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "oneflow/core/register/ofblob.h"
namespace py = pybind11;
inline int Ofblob_GetDataType(uint64_t of_blob_ptr) {
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->data_type();
}
inline size_t OfBlob_NumAxes(uint64_t of_blob_ptr) {
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->NumAxes();
}
inline bool OfBlob_IsDynamic(uint64_t of_blob_ptr) {
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->is_dynamic();
}
inline void OfBlob_CopyShapeFrom(uint64_t of_blob_ptr, py::array_t<int64_t> array) {
py::buffer_info buf = array.request();
int64_t* buf_ptr = (int64_t*)buf.ptr;
size_t size = buf.size;
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->CopyShapeFrom(buf_ptr, size);
}
inline void OfBlob_CopyShapeTo(uint64_t of_blob_ptr, py::array_t<int64_t> array) {
py::buffer_info buf = array.request();
int64_t* buf_ptr = (int64_t*)buf.ptr;
size_t size = buf.size;
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->CopyShapeTo(buf_ptr, size);
}
inline void OfBlob_CopyStaticShapeTo(uint64_t of_blob_ptr, py::array_t<int64_t> array) {
py::buffer_info buf = array.request();
int64_t* buf_ptr = (int64_t*)buf.ptr;
size_t size = buf.size;
using namespace oneflow;
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
return of_blob->CopyStaticShapeTo(buf_ptr, size);
}
#endif // ONEFLOW_API_PYTHON_OFBLOB_OFBLOB_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/profiler/profiler.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("profiler", m) {
m.def("RangePush", [](const std::string& str) { OF_PROFILER_RANGE_PUSH(str); });
m.def("RangePop", []() { OF_PROFILER_RANGE_POP(); });
m.def("ProfilerStart", []() { profiler::ProfilerStart(); });
m.def("ProfilerStop", []() { profiler::ProfilerStop(); });
m.def("EnableProfiler", &profiler::EnableProfiler);
m.def("DisableProfilerAndReturnResult", &profiler::DisableProfilerAndReturnResult);
m.def("StartRecord", &profiler::StartRecord);
m.def("EndRecord", &profiler::EndRecord);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/registry_error.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("CheckAndClearRegistryFlag", &CheckAndClearRegistryFlag);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/job/rank_group.h"
namespace py = pybind11;
namespace oneflow {
namespace {
Maybe<py::bytes> CpuBroadcast(py::bytes* in, int64_t root) {
const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group));
Py_ssize_t length;
char* buffer;
if (GlobalProcessCtx::Rank() == root) {
CHECK_NOTNULL_OR_RETURN(in);
PyBytes_AsStringAndSize(in->ptr(), &buffer, &length);
}
JUST(ccl::Broadcast<DeviceType::kCPU>(&length, &length, sizeof(length), DataType::kChar, root,
parallel_desc, nullptr));
if (GlobalProcessCtx::Rank() == root) {
JUST(ccl::Broadcast<DeviceType::kCPU>(buffer, buffer, length, DataType::kChar, root, // NOLINT
parallel_desc, nullptr));
return *in;
} else {
// https://github.com/pybind/pybind11/issues/1236#issuecomment-527730864
PyBytesObject* bytesObject =
static_cast<PyBytesObject*>(PyObject_Malloc(offsetof(PyBytesObject, ob_sval) + length + 1));
PyObject_INIT_VAR(bytesObject, &PyBytes_Type, length);
bytesObject->ob_shash = -1;
bytesObject->ob_sval[length] = '\0';
buffer = bytesObject->ob_sval;
JUST(ccl::Broadcast<DeviceType::kCPU>(nullptr, buffer, length, DataType::kChar, root,
parallel_desc, nullptr));
return py::reinterpret_steal<py::bytes>(reinterpret_cast<PyObject*>(bytesObject));
}
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("cpu_broadcast",
[](py::bytes in, int64_t root) -> Maybe<py::bytes> { return CpuBroadcast(&in, root); });
m.def("cpu_broadcast", [](const py::none& in, int64_t root) -> Maybe<py::bytes> {
return CpuBroadcast(nullptr, root);
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/thread/thread_consistent_id.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"
namespace py = pybind11;
namespace oneflow {
namespace {
Maybe<void> InitConsistentTransportTokenScope(const std::string& thread_tag,
int64_t thread_consistent_id,
Symbol<RankGroup> rank_group) {
JUST(InitThisThreadUniqueConsistentId(thread_consistent_id, thread_tag));
static thread_local const auto& init_rank_group_scope =
JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
// no unused warning for `init_rank_group_scope`.
(void)(init_rank_group_scope);
return Maybe<void>::Ok();
}
Maybe<void> InitConsistentTransportTokenScope(const std::string& thread_tag,
int64_t thread_consistent_id) {
const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
JUST(InitConsistentTransportTokenScope(thread_tag, thread_consistent_id, rank_group));
return Maybe<void>::Ok();
}
Maybe<void> ApiInitDefaultConsistentTransportTokenScope() {
return InitConsistentTransportTokenScope("main", kThreadConsistentIdMain);
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitDefaultConsistentTransportTokenScope", &ApiInitDefaultConsistentTransportTokenScope);
}
} // namespace oneflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment