Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/core/autograd/autograd_function.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/tensor_tuple.h"
namespace py = pybind11;
namespace oneflow {
namespace {
// Transform input to TensorTuple
Maybe<one::TensorTuple> UnpackTensorTuple(const py::object& input) {
one::TensorTuple tp;
if (one::PyTensor_Check(input.ptr())) {
tp.emplace_back(input.cast<std::shared_ptr<one::Tensor>>());
} else if (py::isinstance<py::tuple>(input)) {
auto tuple = input.cast<py::tuple>();
for (int i = 0; i < tuple.size(); ++i) {
PyObject* obj = tuple[i].ptr();
if (!one::PyTensor_Check(obj)) {
return Error::RuntimeError()
<< "expected Tensor as element " << i << ", but got "
<< one::functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj)));
}
tp.emplace_back(one::PyTensor_Unpack(obj));
}
} else {
return Error::RuntimeError() << "Only support tensor or list of tensors";
}
return tp;
}
// Return single Tensor when TensorTuple's size is one, otherwise py::tuple
py::object PackTensorTuple(const one::TensorTuple& tp) {
if (tp.size() == 1) {
return py::cast(tp.at(0));
} else {
py::tuple out = py::tuple(tp.size());
for (int i = 0; i < tp.size(); ++i) { out[i] = tp.at(i); }
return py::cast<py::object>(out);
}
}
// wrap PyFunction, unpack the inputs from TensorTuple and pack outputs to TensorTuple
one::AutogradFunctionBase::FType PackPyFunctionToFType(const py::function& func) {
return [func](const std::shared_ptr<one::FunctionAutoGradCaptureState>& ctx,
const one::TensorTuple& inputs) {
const py::tuple& a = py::cast(inputs);
py::object res = func(ctx, *a);
return UnpackTensorTuple(res).GetPtrOrThrow();
};
}
} // namespace
namespace one {
ONEFLOW_API_PYBIND11_MODULE("autograd", m) {
py::class_<AutogradFunctionBase, std::shared_ptr<AutogradFunctionBase>>(m, "AutogradFunctionBase")
.def(py::init([]() { return std::make_shared<AutogradFunctionBase>(); }))
.def_static("apply",
[](const std::string& name, const py::function& forward_fn,
const py::function& backward_fn, const py::args& input) -> Maybe<py::object> {
const auto& input_tensor_tuple = JUST(UnpackTensorTuple(input));
const std::shared_ptr<TensorTuple>& res = JUST(AutogradFunctionBase::Apply(
name, PackPyFunctionToFType(forward_fn), PackPyFunctionToFType(backward_fn),
*input_tensor_tuple));
return PackTensorTuple(*res);
});
py::class_<FunctionAutoGradCaptureState, std::shared_ptr<FunctionAutoGradCaptureState>>(
m, "FunctionAutoGradCaptureState")
.def(py::init([]() { return std::make_shared<FunctionAutoGradCaptureState>(); }))
.def("save_for_backward",
[](FunctionAutoGradCaptureState& ctx, const py::args& input) {
const auto& tensors = UnpackTensorTuple(input).GetOrThrow();
for (const auto& tensor : tensors) { ctx.SaveTensorForBackward(tensor); }
})
.def_property_readonly(
"saved_tensors",
[](const FunctionAutoGradCaptureState& ctx) { return py::cast(ctx.SavedTensors()); })
.def("mark_non_differentiable", [](FunctionAutoGradCaptureState& ctx, const py::args& input) {
const auto& tensors = UnpackTensorTuple(input).GetOrThrow();
for (const auto& tensor : tensors) { ctx.MarkNonDifferentiable(tensor); }
});
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/autograd/autograd_mode.h"
namespace py = pybind11;
namespace oneflow {
namespace autograd {
ONEFLOW_API_PYBIND11_MODULE("autograd", m) {
py::class_<AutoGradMode, std::shared_ptr<AutoGradMode>>(m, "AutoGradMode")
.def(py::init([](bool mode) { return std::make_shared<AutoGradMode>(mode); }))
.def("__enter__", [](const AutoGradMode& no_grad_obj) {})
.def("__exit__", [](const AutoGradMode& no_grad_obj, const py::object& type,
const py::object& value, const py::object& traceback) {});
m.def("is_grad_enabled", &GradMode::is_enabled);
}
} // namespace autograd
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/autograd/autograd_engine.h"
namespace py = pybind11;
namespace oneflow {
namespace {
struct FunctionNodeUtil final {
static std::string ToString(const one::FunctionNode& func_node) {
std::stringstream ss;
ss << "<";
ss << func_node.name();
ss << " at " << &func_node;
ss << ">";
return ss.str();
}
};
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<one::FunctionNode, std::shared_ptr<one::FunctionNode>>(m, "FunctionNode")
.def("__str__", &FunctionNodeUtil::ToString)
.def("__repr__", &FunctionNodeUtil::ToString)
.def("_register_hook_dict", []() { TODO(); })
.def_property_readonly(
"next_functions",
[](const one::FunctionNode& func_node) { return func_node.next_functions(); })
.def_property_readonly("metadata", []() { TODO(); })
.def_property_readonly("requires_grad", []() { TODO(); })
.def("register_hook", []() { TODO(); })
.def("name", [](const one::FunctionNode& func_node) { return func_node.name(); });
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <type_traits>
#include <pybind11/pybind11.h>
namespace pybind11 {
namespace detail {
// The condition follows the pybind11 source code
template<typename T>
using IsSupportedByPybind11WhenInsideSharedPtr =
std::is_base_of<type_caster_base<T>, type_caster<T>>;
#define PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(type, py_name) \
protected: \
std::shared_ptr<type> value; \
\
public: \
static constexpr auto name = py_name; \
template<typename T_, enable_if_t<std::is_same<type, remove_cv_t<T_>>::value, int> = 0> \
static handle cast(T_* src, return_value_policy policy, handle parent) { \
if (!src) return none().release(); \
if (policy == return_value_policy::take_ownership) { \
auto h = cast(std::move(*src), policy, parent); \
delete src; \
return h; \
} \
return cast(*src, policy, parent); \
} \
operator type*() { return value.get(); } \
operator type&() { return *value; } \
operator type&&()&& { return std::move(*value); } \
template<typename T_> \
using cast_op_type = pybind11::detail::movable_cast_op_type<T_>
} // namespace detail
} // namespace pybind11
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/caster/common.h"
#include "oneflow/core/common/maybe.h"
namespace pybind11 {
namespace detail {
using oneflow::Maybe;
namespace impl {
template<typename T>
using IsHoldedInsideSharedPtrByMaybe =
std::is_same<decltype(
std::declval<Maybe<T>>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()),
std::shared_ptr<T>>;
template<typename T, typename std::enable_if_t<IsSupportedByPybind11WhenInsideSharedPtr<T>::value
&& IsHoldedInsideSharedPtrByMaybe<T>::value,
int> = 0>
std::shared_ptr<T> GetOrThrowHelper(Maybe<T> x) {
return x.GetPtrOrThrow();
}
template<typename T, typename std::enable_if_t<!IsSupportedByPybind11WhenInsideSharedPtr<T>::value
|| !IsHoldedInsideSharedPtrByMaybe<T>::value,
int> = 0>
T GetOrThrowHelper(Maybe<T> x) {
return x.GetOrThrow();
}
} // namespace impl
// Information about pybind11 custom type caster can be found
// at oneflow/api/python/caster/optional.h, and also at
// https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html
template<typename Type>
struct maybe_caster {
using Value = decltype(impl::GetOrThrowHelper(std::declval<Type>()));
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!src) { return false; }
if (src.is_none()) {
// Maybe<T> (except Maybe<void>) does not accept `None` from Python. Users can use Optional in
// those cases.
return false;
}
value_conv inner_caster;
if (!inner_caster.load(src, convert)) { return false; }
value = std::make_shared<Type>(cast_op<Value&&>(std::move(inner_caster)));
return true;
}
template<typename T>
static handle cast(T&& src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value) {
policy = return_value_policy_override<Value>::policy(policy);
}
return value_conv::cast(impl::GetOrThrowHelper(std::forward<T>(src)), policy, parent);
}
PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe<void>, _("Maybe[void]"));
};
template<>
struct maybe_caster<Maybe<void>> {
template<typename T>
static handle cast(T&& src, return_value_policy policy, handle parent) {
if (!src.IsOk()) { oneflow::ThrowError(src.error()); }
return none().inc_ref();
}
bool load(handle src, bool convert) {
if (src && src.is_none()) {
return true; // None is accepted because NoneType (i.e. void) is the value type of
// Maybe<void>
}
return false;
}
PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe<void>, _("Maybe[void]"));
};
template<typename T>
struct type_caster<Maybe<T>> : public maybe_caster<Maybe<T>> {};
} // namespace detail
} // namespace pybind11
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/caster/common.h"
#include "oneflow/core/common/optional.h"
namespace pybind11 {
namespace detail {
using oneflow::Optional;
namespace impl {
template<typename T>
T& DeferenceIfSharedPtr(std::shared_ptr<T> ptr) {
return *ptr;
}
template<typename T>
T&& DeferenceIfSharedPtr(T&& obj) {
return std::forward<T>(obj);
}
template<typename T>
using IsHoldedInsideSharedPtrByOptional =
std::is_same<typename Optional<T>::storage_type, std::shared_ptr<T>>;
template<typename T, typename std::enable_if_t<IsSupportedByPybind11WhenInsideSharedPtr<T>::value
&& IsHoldedInsideSharedPtrByOptional<T>::value,
int> = 0>
std::shared_ptr<T> GetDataHelper(Optional<T> x) {
return CHECK_JUST(x);
}
template<typename T, typename std::enable_if_t<!IsSupportedByPybind11WhenInsideSharedPtr<T>::value
|| !IsHoldedInsideSharedPtrByOptional<T>::value,
int> = 0>
T GetDataHelper(Optional<T> x) {
return DeferenceIfSharedPtr<T>(CHECK_JUST(x));
}
} // namespace impl
// Code is copied from pybind11 include/pybind11/stl.h
// Comments wrapped by /* */ are copied from
// https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html
template<typename Type>
struct oneflow_optional_caster {
using Value = decltype(impl::GetDataHelper(std::declval<Type>()));
using value_conv = make_caster<Value>;
/**
* Conversion part 1 (Python->C++): convert a PyObject into a Optional<T>
* instance or return false upon failure. The second argument
* indicates whether implicit conversions should be applied.
*/
bool load(handle src, bool convert) {
if (!src) { return false; }
if (src.is_none()) {
return true; // default-constructed value is already empty
}
value_conv inner_caster;
if (!inner_caster.load(src, convert)) { return false; }
value = cast_op<Value&&>(std::move(inner_caster));
return true;
}
/**
* Conversion part 2 (C++ -> Python): convert an Optional<T> instance into
* a Python object. The second and third arguments are used to
* indicate the return value policy and parent object (for
* ``return_value_policy::reference_internal``) and are generally
* ignored by implicit casters.
*/
template<typename T>
static handle cast(T&& src, return_value_policy policy, handle parent) {
if (!src) { return none().inc_ref(); }
if (!std::is_lvalue_reference<T>::value) {
policy = return_value_policy_override<Value>::policy(policy);
}
return value_conv::cast(impl::GetDataHelper(std::forward<T>(src)), policy, parent);
}
/**
* This macro establishes the name 'Optional[T]' in
* function signatures and declares a local variable
* 'value' of type inty
*/
PYBIND11_TYPE_CASTER(Type, _("Optional[") + value_conv::name + _("]"));
};
template<typename T>
struct type_caster<Optional<T>> : public oneflow_optional_caster<Optional<T>> {};
} // namespace detail
} // namespace pybind11
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/caster/common.h"
#include "oneflow/api/python/framework/tensor.h"
namespace pybind11 {
namespace detail {
template<typename T>
struct tensor_type_caster {
public:
bool load(handle src, bool convert) {
using namespace oneflow::one;
value_ = nullptr;
if (!src) { return false; }
if (src.is_none()) { return true; }
if (!PyTensor_Check(src.ptr())) { return false; }
value_ = PyTensor_Unpack(src.ptr());
return true;
}
template<typename U>
static handle cast(U&& src, return_value_policy policy, handle parent) {
using namespace oneflow::one;
return reinterpret_steal<object>(PyTensor_New(std::const_pointer_cast<Tensor>(src))).release();
}
operator std::shared_ptr<T>*() { return &value_; }
operator std::shared_ptr<T>&() { return value_; }
operator std::shared_ptr<T>&&() && { return std::move(value_); }
static constexpr auto name = _("tensor");
template<typename U>
using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;
protected:
std::shared_ptr<T> value_;
};
template<typename T>
struct parameter_type_caster {
public:
bool load(handle src, bool convert) {
using namespace oneflow::one;
value_ = nullptr;
if (!src) { return false; }
if (src.is_none()) { return true; }
if (!PyTensor_Check(src.ptr())) { return false; }
value_ = PyTensor_Unpack(src.ptr());
return true;
}
template<typename U>
static handle cast(U&& src, return_value_policy policy, handle parent) {
using namespace oneflow::one;
return reinterpret_steal<object>(PyParameter_New(std::const_pointer_cast<Parameter>(src)))
.release();
}
operator std::shared_ptr<T>*() { return &value_; }
operator std::shared_ptr<T>&() { return value_; }
operator std::shared_ptr<T>&&() && { return std::move(value_); }
static constexpr auto name = _("parameter");
template<typename U>
using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;
protected:
std::shared_ptr<T> value_;
};
template<>
struct type_caster<std::shared_ptr<oneflow::one::Tensor>>
: public tensor_type_caster<oneflow::one::Tensor> {};
template<>
struct type_caster<std::shared_ptr<const oneflow::one::Tensor>>
: public tensor_type_caster<const oneflow::one::Tensor> {};
template<>
struct type_caster<std::shared_ptr<oneflow::one::Parameter>>
: public parameter_type_caster<oneflow::one::Parameter> {};
template<>
struct type_caster<std::shared_ptr<const oneflow::one::Parameter>>
: public parameter_type_caster<const oneflow::one::Parameter> {};
} // namespace detail
} // namespace pybind11
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
namespace py = pybind11;
namespace oneflow {
class A {
public:
void inc_x() { x++; }
int get_x() { return x; }
private:
int x = 0;
};
std::shared_ptr<A> get_singleton_a() {
static std::shared_ptr<A> a = std::make_shared<A>();
return a;
}
ONEFLOW_API_PYBIND11_MODULE("test_api", m) {
py::class_<A, std::shared_ptr<A>>(m, "A").def("inc_x", &A::inc_x).def("get_x", &A::get_x);
m.def("get_singleton_a", []() -> Maybe<A> { return get_singleton_a(); });
m.def("increase_x_of_a_if_not_none", [](const Optional<A>& a) -> Optional<A> {
a.map([](const std::shared_ptr<A>& a) -> std::shared_ptr<A> {
a->inc_x();
return a;
});
return a;
});
m.def("increase_if_not_none",
[](const Optional<int>& x) -> Optional<int> { return x.map([](int i) { return i + 1; }); });
m.def("divide", [](float x, float y) -> Maybe<float> {
CHECK_NE_OR_RETURN(y, 0);
return x / y;
});
m.def("throw_if_zero", [](int x) -> Maybe<void> {
CHECK_NE_OR_RETURN(x, 0);
return Maybe<void>::Ok();
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/dtype.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
m.def("GetProtoDtype4OfDtype",
[](const Symbol<DType>& x) { return static_cast<int>(x->data_type()); });
m.def("GetDTypeByDataType",
[](int data_type) { return DType::Get(static_cast<DataType>(data_type)); });
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h"
ONEFLOW_API_PYBIND11_MODULE("eager", m) {
using namespace oneflow;
namespace py = pybind11;
m.def(
"Sync", []() { return vm::ClusterSync(); }, py::call_guard<py::gil_scoped_release>());
py::class_<one::DevVmDepObjectConsumeModeGuard,
std::shared_ptr<one::DevVmDepObjectConsumeModeGuard>>(
m, "DevVmDepObjectConsumeModeGuard");
m.def("SourceOpOnlyResourceDependenceModeGuard", []() {
return std::make_shared<one::DevVmDepObjectConsumeModeGuard>(
one::DevVmDepObjectConsumeMode::NONE);
});
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/env/env.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/framework/shut_down_util.h"
#include "oneflow/core/device/cuda_util.h"
namespace py = pybind11;
namespace oneflow {
Maybe<void> SwitchToShuttingDownPhase(EnvGlobalObjectsScope* env, bool is_normal_exit) {
if (is_normal_exit) {
JUST(vm::ClusterSync());
auto* vm = JUST(SingletonMaybe<VirtualMachine>());
JUST(vm->CloseVMThreads());
}
JUST(env->init_is_normal_exit(is_normal_exit));
SetShuttingDown(true);
return Maybe<void>::Ok();
}
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("CurrentResource", &CurrentResource);
m.def("EnvResource", &EnvResource);
m.def("EnableEagerEnvironment", &EnableEagerEnvironment);
py::class_<oneflow::EnvGlobalObjectsScope, std::shared_ptr<oneflow::EnvGlobalObjectsScope>>(
m, "EnvContext")
.def(py::init<const std::string&>())
.def("SwitchToShuttingDownPhase", &SwitchToShuttingDownPhase,
py::call_guard<py::gil_scoped_release>());
m.def("CurrentMachineId", &CurrentMachineId);
m.def("GetRank", &GetRank);
m.def("GetWorldSize", &GetWorldSize);
m.def("GetNodeSize", &GetNodeSize);
m.def("GetLocalRank", &GetLocalRank);
m.def("InitRDMA", &InitRDMA);
m.def("RDMAIsInitialized", &RDMAIsInitialized);
m.def("CudaGetDeviceCount", &CudaGetDeviceCount);
m.def("EmptyCache", &EmptyCache);
#ifdef WITH_CUDA
m.def("GetCudaDeviceIndex", &GetCudaDeviceIndex);
m.def("SetCudaDeviceIndex", &SetCudaDeviceIndex);
m.def("CudaSynchronize", &CudaSynchronize);
m.def("GetCUDAMemoryUsed", &GetCUDAMemoryUsed);
#endif // WITH_CUDA
#ifdef WITH_ROCM
m.def("GetCudaDeviceIndex", &GetCudaDeviceIndex);
m.def("SetCudaDeviceIndex", &SetCudaDeviceIndex);
m.def("CudaSynchronize", &CudaSynchronize);
m.def("GetCUDAMemoryUsed", &GetCUDAMemoryUsed);
#endif // WITH_ROCM
m.def("SetFLAGS_alsologtostderr", &SetFLAGS_alsologtostderr);
m.def("GetFLAGS_alsologtostderr", &GetFLAGS_alsologtostderr);
m.def("SetFLAGS_v", &SetFLAGS_v);
m.def("GetFLAGS_v", &GetFLAGS_v);
m.def("SetGraphLRVerbose", &SetGraphLRVerbose);
m.def("GetGraphLRVerbose", &GetGraphLRVerbose);
m.def("SetGraphDebugMaxPyStackDepth", &SetGraphDebugMaxPyStackDepth);
m.def("GetGraphDebugMaxPyStackDepth", &GetGraphDebugMaxPyStackDepth);
m.def("SetGraphDebugMode", &SetGraphDebugMode);
m.def("GetGraphDebugMode", &GetGraphDebugMode);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_ENV_ENV_H_
#define ONEFLOW_API_PYTHON_ENV_ENV_H_
#include <string>
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/job/cluster.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/graph_scope_vars.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/rpc/include/base.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/virtual_machine.h"
namespace oneflow {
inline Maybe<std::string> CurrentResource() {
CHECK_NOTNULL_OR_RETURN((Singleton<ResourceDesc, ForSession>::Get()));
return PbMessage2TxtString(Singleton<ResourceDesc, ForSession>::Get()->resource());
}
inline Maybe<std::string> EnvResource() {
CHECK_NOTNULL_OR_RETURN((Singleton<ResourceDesc, ForEnv>::Get()));
return PbMessage2TxtString(Singleton<ResourceDesc, ForEnv>::Get()->resource());
}
inline Maybe<void> EnableEagerEnvironment(bool enable_eager_execution) {
CHECK_NOTNULL_OR_RETURN((Singleton<bool, EagerExecution>::Get()));
*Singleton<bool, EagerExecution>::Get() = enable_eager_execution;
return Maybe<void>::Ok();
}
inline Maybe<long long> CurrentMachineId() { return GlobalProcessCtx::Rank(); }
inline Maybe<int64_t> GetRank() { return GlobalProcessCtx::Rank(); }
inline Maybe<size_t> GetWorldSize() { return GlobalProcessCtx::WorldSize(); }
inline Maybe<size_t> GetNodeSize() { return GlobalProcessCtx::NodeSize(); }
inline Maybe<size_t> GetLocalRank() { return GlobalProcessCtx::LocalRank(); }
inline Maybe<size_t> CudaGetDeviceCount() {
return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(DeviceType::kCUDA);
}
inline Maybe<void> SetFLAGS_alsologtostderr(bool flag) {
FLAGS_alsologtostderr = flag;
return Maybe<void>::Ok();
}
inline Maybe<bool> GetFLAGS_alsologtostderr() {
return FLAGS_alsologtostderr;
} // namespace oneflow
inline Maybe<void> SetFLAGS_v(int32_t v_level) {
FLAGS_v = v_level;
return Maybe<void>::Ok();
}
inline Maybe<int32_t> GetFLAGS_v() { return FLAGS_v; }
inline Maybe<void> EmptyCache() {
JUST(vm::CurrentRankSync());
auto* vm = JUST(SingletonMaybe<VirtualMachine>());
JUST(vm->ShrinkAllMem());
return Maybe<void>::Ok();
}
inline Maybe<void> SetGraphLRVerbose(bool verbose) {
SetGraphVerboseStepLr(verbose);
return Maybe<void>::Ok();
}
inline bool GetGraphLRVerbose() { return IsOpenGraphVerboseStepLr(); }
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_ENV_ENV_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/core/common/exception.h"
#include "oneflow/core/common/error.h"
#include "oneflow/api/python/of_api_registry.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("exception", m) {
m.def("GetThreadLocalLastError", &ThreadLocalError);
py::register_exception<oneflow::Exception>(m, "Exception");
py::register_exception<oneflow::RuntimeException>(m, "RuntimeError", PyExc_RuntimeError);
py::register_exception<oneflow::TypeException>(m, "TypeError", PyExc_TypeError);
py::register_exception<oneflow::IndexException>(m, "IndexError", PyExc_IndexError);
py::register_exception<oneflow::NotImplementedException>(m, "NotImplementedError",
PyExc_NotImplementedError);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_
#define ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_
#include <Python.h>
#include <pybind11/pybind11.h>
#include "oneflow/core/common/exception.h"
namespace py = pybind11;
#define HANDLE_ERRORS try {
#define END_HANDLE_ERRORS_RETSTMT(retstmt) \
} \
catch (py::error_already_set & e) { \
e.restore(); \
retstmt; \
} \
catch (const oneflow::RuntimeException& e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
retstmt; \
} \
catch (const oneflow::IndexException& e) { \
PyErr_SetString(PyExc_IndexError, e.what()); \
retstmt; \
} \
catch (const oneflow::TypeException& e) { \
PyErr_SetString(PyExc_TypeError, e.what()); \
retstmt; \
} \
catch (const oneflow::NotImplementedException& e) { \
PyErr_SetString(PyExc_NotImplementedError, e.what()); \
retstmt; \
} \
catch (const std::exception& e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
retstmt; \
}
#define END_HANDLE_ERRORS END_HANDLE_ERRORS_RETSTMT(return NULL)
#define END_HANDLE_ERRORS_RET(retval) END_HANDLE_ERRORS_RETSTMT(return retval)
#define END_HANDLE_ERRORS_NORET END_HANDLE_ERRORS_RETSTMT(void)
#endif // ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/of_api_registry.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("flags", m) {
m.def("with_cuda", []() {
#ifdef WITH_CUDA
return true;
#else
return false;
#endif // WITH_CUDA
});
m.def("cuda_version", []() {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
return 0;
#endif // WITH_CUDA
});
m.def("use_cxx11_abi", []() {
#if _GLIBCXX_USE_CXX11_ABI == 1
return true;
#else
return false;
#endif // _GLIBCXX_USE_CXX11_ABI
});
m.def("with_mlir", []() {
#ifdef WITH_MLIR
return true;
#else
return false;
#endif // WITH_MLIR
});
m.def("with_mlir_cuda_codegen", []() {
#ifdef WITH_MLIR_CUDA_CODEGEN
return true;
#else
return false;
#endif // WITH_MLIR_CUDA_CODEGEN
});
m.def("with_rdma", []() {
#ifdef WITH_RDMA
return true;
#else
return false;
#endif // WITH_RDMA
});
m.def("has_rpc_backend_grpc", []() {
#ifdef RPC_BACKEND_GRPC
return true;
#else
return false;
#endif // RPC_BACKEND_GRPC
});
m.def("has_rpc_backend_local", []() {
#ifdef RPC_BACKEND_LOCAL
return true;
#else
return false;
#endif // RPC_BACKEND_LOCAL
});
#define STRINGIFY(x) STRINGIFY_(x)
#define STRINGIFY_(x) #x
m.def("cmake_build_type", []() {
#ifdef ONEFLOW_CMAKE_BUILD_TYPE
return std::string(STRINGIFY(ONEFLOW_CMAKE_BUILD_TYPE));
#else
return std::string("Undefined");
#endif // ONEFLOW_CMAKE_BUILD_TYPE
});
#undef STRINGIFY
#undef STRINGIFY_
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/control/global_process_ctx.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Symbol<Device>, std::shared_ptr<Symbol<Device>>>(m, "device")
.def(py::init([](const std::string& type_or_type_with_device_id) {
return Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow();
}))
.def(py::init([](const std::string& type, int64_t device_id) {
return Device::New(type, device_id).GetOrThrow();
}))
.def_property_readonly("type", [](const Symbol<Device>& d) { return d->type(); })
.def_property_readonly("index", [](const Symbol<Device>& d) { return d->device_id(); })
.def("__str__", [](const Symbol<Device>& d) { return d->ToString(); })
.def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); })
.def(py::self == py::self)
.def(py::hash(py::self));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/throw.h"
namespace py = pybind11;
namespace oneflow {
py::object AddFunctionDoc(py::object f, const std::string& doc_string) {
static std::vector<std::string> all_doc_strings;
all_doc_strings.emplace_back(doc_string);
const char* doc_str = all_doc_strings.back().c_str();
PyObject* obj = f.ptr();
if (PyCFunction_Check(obj)) {
auto* f = (PyCFunctionObject*)obj;
if (f->m_ml->ml_doc) {
THROW(RuntimeError) << "function " << f->m_ml->ml_name << " already has a docstring "
<< "shows: " << f->m_ml->ml_doc;
}
f->m_ml->ml_doc = doc_str;
} else if (PyFunction_Check(obj)) {
auto* f = (PyFunctionObject*)obj;
if (f->func_doc != Py_None) {
THROW(RuntimeError) << "function "
<< PyBytes_AsString(
PyUnicode_AsEncodedString(f->func_name, "utf-8", "~E~"))
<< " already has a docstring";
}
f->func_doc = PyUnicode_FromString(doc_str);
} else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
PyMethodDescrObject* f = (PyMethodDescrObject*)obj;
if (f->d_method->ml_doc) {
THROW(RuntimeError) << "function " << f->d_method->ml_name << "already has a docstring";
}
f->d_method->ml_doc = doc_str;
} else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
PyMethodDescrObject* f = (PyMethodDescrObject*)obj;
if (f->d_method->ml_doc) {
THROW(RuntimeError) << "function " << f->d_method->ml_name << "already has a docstring";
}
f->d_method->ml_doc = doc_str;
} else if (py::isinstance<py::detail::generic_type>(f)) {
if (py::hasattr(f, "__doc__")) {
auto doc = py::getattr(f, "__doc__");
if (!doc.is(py::none())) {
THROW(RuntimeError) << Py_TYPE(obj)->tp_name << " already has a docstring";
}
}
py::setattr(f, "__doc__", py::reinterpret_steal<py::object>(PyUnicode_FromString(doc_str)));
} else if (Py_TYPE(obj)->tp_name == PyProperty_Type.tp_name) {
py::setattr(f, "__doc__", py::reinterpret_steal<py::object>(PyUnicode_FromString(doc_str)));
} else if (PyInstanceMethod_Check(obj)) {
auto* f = (PyCFunctionObject*)(PyInstanceMethod_Function(obj));
f->m_ml->ml_doc = doc_str;
} else {
THROW(RuntimeError) << "function is " << Py_TYPE(obj)->tp_name << ", not a valid function";
}
f.inc_ref();
return f;
}
py::object ReplaceDoc(py::object f, const std::string& doc_string) {
static std::vector<std::string> all_doc_strings;
all_doc_strings.emplace_back(doc_string);
const char* doc_str = all_doc_strings.back().c_str();
PyObject* obj = f.ptr();
if (PyCFunction_Check(obj)) {
auto* f = (PyCFunctionObject*)obj;
if (!f->m_ml->ml_doc) {
THROW(RuntimeError) << "function " << f->m_ml->ml_name << " has not a docstring yet.";
}
f->m_ml->ml_doc = doc_str;
} else if (PyFunction_Check(obj)) {
auto* f = (PyFunctionObject*)obj;
if (f->func_doc == Py_None) {
THROW(RuntimeError) << "function "
<< PyBytes_AsString(
PyUnicode_AsEncodedString(f->func_name, "utf-8", "~E~"))
<< " has not a docstring yet.";
}
Py_DECREF(f->func_doc);
f->func_doc = PyUnicode_FromString(doc_str);
} else {
THROW(RuntimeError) << "function is " << Py_TYPE(obj)->tp_name << ", not a valid function.";
}
f.inc_ref();
return f;
}
} // namespace oneflow
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("add_doc", &oneflow::AddFunctionDoc);
m.def("reset_doc", &oneflow::ReplaceDoc);
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/dtype.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Symbol<DType>, std::shared_ptr<Symbol<DType>>>(m, "dtype")
.def_property_readonly("is_signed", [](const Symbol<DType>& d) { return d->is_signed(); })
.def_property_readonly("is_complex", [](const Symbol<DType>& d) { return d->is_complex(); })
.def_property_readonly("is_floating_point",
[](const Symbol<DType>& d) { return d->is_floating_point(); })
.def("__str__", [](const Symbol<DType>& d) { return d->name(); })
.def("__repr__", [](const Symbol<DType>& d) { return d->name(); })
.def(py::self == py::self)
.def(py::hash(py::self))
.def(py::pickle(
[](const Symbol<DType>& dtype) { // __getstate__
return static_cast<int>(dtype->data_type());
},
[](int t) { // __setstate__
return CHECK_JUST(DType::Get(DataType(t)));
}))
.def_property_readonly("bytes", [](const Symbol<DType>& dtype) { return dtype->bytes(); })
.def("get", [](const int data_type_enum) {
return CHECK_JUST(DType::Get(static_cast<DataType>(data_type_enum)));
});
m.attr("bool") = &CHECK_JUST(DType::Get(DataType::kBool));
m.attr("char") = &CHECK_JUST(DType::Get(DataType::kChar));
m.attr("float16") = &CHECK_JUST(DType::Get(DataType::kFloat16));
m.attr("float") = &CHECK_JUST(DType::Get(DataType::kFloat));
m.attr("float32") = &CHECK_JUST(DType::Get(DataType::kFloat));
m.attr("double") = &CHECK_JUST(DType::Get(DataType::kDouble));
m.attr("float64") = &CHECK_JUST(DType::Get(DataType::kDouble));
m.attr("int8") = &CHECK_JUST(DType::Get(DataType::kInt8));
m.attr("int32") = &CHECK_JUST(DType::Get(DataType::kInt32));
m.attr("int64") = &CHECK_JUST(DType::Get(DataType::kInt64));
m.attr("uint8") = &CHECK_JUST(DType::Get(DataType::kUInt8));
m.attr("record") = &CHECK_JUST(DType::Get(DataType::kOFRecord));
m.attr("tensor_buffer") = &CHECK_JUST(DType::Get(DataType::kTensorBuffer));
m.attr("bfloat16") = &CHECK_JUST(DType::Get(DataType::kBFloat16));
m.attr("uint16") = &CHECK_JUST(DType::Get(DataType::kUInt16));
m.attr("uint32") = &CHECK_JUST(DType::Get(DataType::kUInt32));
m.attr("uint64") = &CHECK_JUST(DType::Get(DataType::kUInt64));
m.attr("uint128") = &CHECK_JUST(DType::Get(DataType::kUInt128));
m.attr("int16") = &CHECK_JUST(DType::Get(DataType::kInt16));
m.attr("int128") = &CHECK_JUST(DType::Get(DataType::kInt128));
m.attr("complex32") = &CHECK_JUST(DType::Get(DataType::kComplex32));
m.attr("complex64") = &CHECK_JUST(DType::Get(DataType::kComplex64));
m.attr("complex128") = &CHECK_JUST(DType::Get(DataType::kComplex128));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/foreign_callback.h"
namespace py = pybind11;
namespace oneflow {
class PyForeignCallback : public ForeignCallback {
public:
// Inherit the constructors
using ForeignCallback::ForeignCallback;
// Trampoline (need one for each virtual function)
void OfBlobCall(int64_t unique_id, int64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, /* Return type */
ForeignCallback, /* Parent class */
OfBlobCall, /* Name of function in C++ (must match Python name) */
unique_id, ofblob_ptr /* Argument(s) */
);
}
void RemoveForeignCallback(int64_t unique_id) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, RemoveForeignCallback, unique_id);
}
};
} // namespace oneflow
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<ForeignCallback, PyForeignCallback, std::shared_ptr<ForeignCallback>>(
m, "ForeignCallback")
.def(py::init<>())
.def("OfBlobCall", &ForeignCallback::OfBlobCall)
.def("RemoveForeignCallback", &ForeignCallback::RemoveForeignCallback);
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/foreign_watcher.h"
namespace py = pybind11;
namespace oneflow {
class PyForeignWatcher : public ForeignWatcher {
public:
using ForeignWatcher::ForeignWatcher;
void Call(const std::string& handler_uuid, int64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, ForeignWatcher, Call, handler_uuid, ofblob_ptr);
}
};
} // namespace oneflow
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<ForeignWatcher, PyForeignWatcher, std::shared_ptr<ForeignWatcher>>(m, "ForeignWatcher")
.def(py::init<>())
.def("Call", &ForeignWatcher::Call);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment