/* Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "oneflow/api/python/caster/common.h" #include "oneflow/core/common/maybe.h" namespace pybind11 { namespace detail { using oneflow::Maybe; namespace impl { template using IsHoldedInsideSharedPtrByMaybe = std::is_same>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()), std::shared_ptr>; template::value && IsHoldedInsideSharedPtrByMaybe::value, int> = 0> std::shared_ptr GetOrThrowHelper(Maybe x) { return x.GetPtrOrThrow(); } template::value || !IsHoldedInsideSharedPtrByMaybe::value, int> = 0> T GetOrThrowHelper(Maybe x) { return x.GetOrThrow(); } } // namespace impl // Information about pybind11 custom type caster can be found // at oneflow/api/python/caster/optional.h, and also at // https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html template struct maybe_caster { using Value = decltype(impl::GetOrThrowHelper(std::declval())); using value_conv = make_caster; bool load(handle src, bool convert) { if (!src) { return false; } if (src.is_none()) { // Maybe (except Maybe) does not accept `None` from Python. Users can use Optional in // those cases. return false; } value_conv inner_caster; if (!inner_caster.load(src, convert)) { return false; } value = std::make_shared(cast_op(std::move(inner_caster))); return true; } template static handle cast(T&& src, return_value_policy policy, handle parent) { if (!std::is_lvalue_reference::value) { policy = return_value_policy_override::policy(policy); } return value_conv::cast(impl::GetOrThrowHelper(std::forward(src)), policy, parent); } PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe, _("Maybe[void]")); }; template<> struct maybe_caster> { template static handle cast(T&& src, return_value_policy policy, handle parent) { if (!src.IsOk()) { oneflow::ThrowError(src.error()); } return none().inc_ref(); } bool load(handle src, bool convert) { if (src && src.is_none()) { return true; // None is accepted because NoneType (i.e. void) is the value type of // Maybe } return false; } PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe, _("Maybe[void]")); }; template struct type_caster> : public maybe_caster> {}; } // namespace detail } // namespace pybind11