Commit abfbe166 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

dtk23.04

parent f262efc9
Pipeline #250 failed with stages
in 0 seconds
...@@ -37,10 +37,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { ...@@ -37,10 +37,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
})) }))
.def("manual_seed", .def("manual_seed",
[](const std::shared_ptr<one::Generator>& generator, [](const std::shared_ptr<one::Generator>& generator,
const py::object& seed) -> Maybe<void> { const py::object& seed) -> std::shared_ptr<one::Generator> {
int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); int64_t seed_val = (one::functional::PyUnpackLong(seed.ptr())).GetOrThrow();
generator->set_current_seed(seed_val); generator->set_current_seed(seed_val);
return Maybe<void>::Ok(); return generator;
}) })
.def("initial_seed", &one::Generator::current_seed) .def("initial_seed", &one::Generator::current_seed)
.def("seed", &one::Generator::seed) .def("seed", &one::Generator::seed)
......
...@@ -25,10 +25,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { ...@@ -25,10 +25,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("GetCurrentScope", &GetCurrentScope); m.def("GetCurrentScope", &GetCurrentScope);
m.def("MakeInitialScope", m.def("MakeInitialScope",
[](const std::string& job_conf_str, Symbol<ParallelDesc> placement, [](const std::string& job_conf_str, Symbol<ParallelDesc> placement,
bool is_mirrored) -> Maybe<Scope> { bool is_local) -> Maybe<Scope> {
JobConfigProto job_conf; JobConfigProto job_conf;
CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed"; CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed";
return MakeInitialScope(job_conf, placement, is_mirrored); return MakeInitialScope(job_conf, placement, is_local);
}); });
m.def("InitGlobalScopeStack", &InitThreadLocalScopeStack); m.def("InitGlobalScopeStack", &InitThreadLocalScopeStack);
......
...@@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/session_util.h" #include "oneflow/core/framework/session_util.h"
...@@ -22,20 +21,9 @@ namespace py = pybind11; ...@@ -22,20 +21,9 @@ namespace py = pybind11;
namespace oneflow { namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) { ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Session, std::shared_ptr<Session>>(m, "Session") m.def("GetDefaultSessionId", []() -> int64_t { return GetDefaultSessionId().GetOrThrow(); });
.def_property_readonly("id", &Session::id) m.def("RegsterSessionId", &RegsterSessionId);
.def("push_mirrored_strategy_enabled", &Session::PushMirroredStrategyEnabled) m.def("ClearSessionId", &ClearSessionId);
.def("pop_mirrored_strategy_enabled", &Session::PopMirroredStrategyEnabled)
.def("is_mirrored_strategy_enabled", &Session::IsMirroredStrategyEnabled)
.def("is_consistent_strategy_enabled", &Session::IsConsistentStrategyEnabled)
.def("is_mirrored_strategy_enabled_stack_size",
[](const Session* sess) { return sess->is_mirrored_strategy_enabled_stack()->size(); });
m.def("GetDefaultSessionId", &GetDefaultSessionId);
m.def("RegsiterSession", &RegsiterSession);
m.def("GetDefaultSession", &GetDefaultSession);
m.def("ClearSessionById", &ClearSessionById);
} }
} // namespace oneflow } // namespace oneflow
...@@ -37,85 +37,4 @@ Shape TensorSize_AsShape(PyObject* self); ...@@ -37,85 +37,4 @@ Shape TensorSize_AsShape(PyObject* self);
} // namespace oneflow } // namespace oneflow
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// pybind11 handle type that exposes oneflow's TensorSize (a PyTuple subtype,
// analogous to torch.Size) as a first-class pybind11 object named `shape`.
class shape : public object {
 public:
  // Standard pybind11 conversion boilerplate: type check via
  // oneflow::TensorSize_Check, conversion via raw_shape below.
  PYBIND11_OBJECT_CVT(shape, object, oneflow::TensorSize_Check, raw_shape)
  // Allocates a fresh TensorSize with `size` slots; `stolen_t` tells pybind11
  // the new reference from TensorSize_New is transferred, not borrowed.
  explicit shape(size_t size = 0) : object(oneflow::TensorSize_New((ssize_t)size), stolen_t{}) {
    if (!m_ptr) pybind11_fail("Could not allocate tensor size object!");
  }
  // Tuple-style accessors; m_ptr is a tuple (TensorSize subclasses tuple).
  size_t size() const { return (size_t)PyTuple_Size(m_ptr); }
  bool empty() const { return size() == 0; }
  detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
  detail::item_accessor operator[](handle h) const { return object::operator[](h); }
  detail::tuple_iterator begin() const { return {*this, 0}; }
  detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }

 private:
  // Conversion hook used by PYBIND11_OBJECT_CVT: returns a new reference.
  // If `op` already is a TensorSize, incref and return it; otherwise call
  // TensorSize_Type(op) to build one (e.g. from an arbitrary sequence).
  static PyObject* raw_shape(PyObject* op) {
    if (oneflow::TensorSize_Check(op)) return handle(op).inc_ref().ptr();
    return PyObject_CallFunctionObjArgs((PyObject*)&oneflow::TensorSize_Type, op, NULL);
  }
};
PYBIND11_NAMESPACE_BEGIN(detail)
// Shared pybind11 type-caster implementation for oneflow::Shape and its
// shared_ptr variants: Python TensorSize <-> C++ Shape.
template<typename T>
struct shape_type_caster {
 public:
  // Python -> C++. Returns true on success. Note: a Python `None` loads
  // successfully with value_ left null (callers receive a null shared_ptr);
  // any non-TensorSize object is rejected.
  bool load(handle src, bool convert) {
    value_ = nullptr;
    if (src && src.is_none()) { return true; }
    if (!oneflow::TensorSize_Check(src.ptr())) { return false; }
    value_ = std::make_shared<T>(oneflow::TensorSize_AsShape(src.ptr()));
    return true;
  }

  // C++ -> Python: build a new TensorSize from the Shape.
  template<typename U>
  static handle cast(U&& src, return_value_policy /*policy*/, handle /*parent*/) {
    return cast_impl(std::forward<U>(src));
  }

  // Pointer overload: a null pointer maps to Python None.
  template<typename U>
  static handle cast(U* src, return_value_policy policy, handle parent) {
    if (!src) { return none().release(); }
    return cast(*src, policy, parent);
  }

  // Implicit-conversion operators pybind11 uses to hand the loaded value to
  // the bound function, in every form a parameter might request.
  operator T*() { return value_.get(); }
  operator T&() { return *value_; }
  operator T&&() && { return std::move(*value_); }
  operator std::shared_ptr<T>*() { return &value_; }
  operator std::shared_ptr<T>&() { return value_; }
  operator std::shared_ptr<T>&&() && { return std::move(value_); }
  // Name shown in generated signatures/docstrings.
  static constexpr auto name = _("shape");
  template<typename U>
  using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;

 private:
  static handle cast_impl(const oneflow::Shape& src) {
    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(src)).release();
  }

  static handle cast_impl(const std::shared_ptr<const oneflow::Shape>& src) {
    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(*src)).release();
  }

 protected:
  std::shared_ptr<T> value_;  // loaded value; null when Python side was None
};
// Route Shape conversions (by value, shared_ptr, and shared_ptr-to-const)
// through the shared shape_type_caster above.
template<>
struct type_caster<oneflow::Shape> : public shape_type_caster<oneflow::Shape> {};
template<>
struct type_caster<std::shared_ptr<oneflow::Shape>> : public shape_type_caster<oneflow::Shape> {};
template<>
struct type_caster<std::shared_ptr<const oneflow::Shape>>
    : public shape_type_caster<const oneflow::Shape> {};
PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_ #endif // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/framework/thread.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/stream_set.h"
#include "oneflow/core/framework/stream_guard.h"
namespace py = pybind11;
// Python bindings for stream management:
//   StreamSet   — constructed from an AsyncThread's uid via StreamSet::New.
//   StreamGuard — scoped guard built around a StreamConverter that wraps a
//                 StreamSet (shared ownership via shared_ptr holder types).
ONEFLOW_API_PYBIND11_MODULE("", m) {
  using namespace oneflow;
  py::class_<StreamSet, std::shared_ptr<StreamSet>>(m, "StreamSet")
      .def(py::init([](const AsyncThread& async_thread) {
        // GetPtrOrThrow converts a failed Maybe into a Python exception.
        return StreamSet::New(async_thread.thread_uid()).GetPtrOrThrow();
      }));
  py::class_<StreamGuard, std::shared_ptr<StreamGuard>>(m, "StreamGuard")
      .def(py::init([](const std::shared_ptr<StreamSet>& stream_set) {
        auto stream_converter = std::make_shared<StreamConverter>(stream_set);
        return std::make_shared<StreamGuard>(stream_converter);
      }));
}
...@@ -25,7 +25,6 @@ limitations under the License. ...@@ -25,7 +25,6 @@ limitations under the License.
#include "oneflow/api/python/functional/functional_api.yaml.pybind.h" #include "oneflow/api/python/functional/functional_api.yaml.pybind.h"
#include "oneflow/api/python/functional/tensor_api.yaml.pybind.h" #include "oneflow/api/python/functional/tensor_api.yaml.pybind.h"
#include "oneflow/api/python/of_api_registry.h" #include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/api/python/utils/tensor_utils.h" #include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor.h"
...@@ -36,6 +35,7 @@ limitations under the License. ...@@ -36,6 +35,7 @@ limitations under the License.
#include "oneflow/core/framework/placement_utils.h" #include "oneflow/core/framework/placement_utils.h"
#include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/tensor_index.h" #include "oneflow/core/functional/tensor_index.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -55,29 +55,84 @@ namespace one { ...@@ -55,29 +55,84 @@ namespace one {
PyTypeObject* PyTensorObject_Type = NULL; PyTypeObject* PyTensorObject_Type = NULL;
PyTypeObject* PyParameterObject_Type = NULL; PyTypeObject* PyParameterObject_Type = NULL;
namespace {
// Compile-time map from a tensor C++ type to the PyTypeObject used to
// allocate its Python wrapper (Tensor -> PyTensorObject_Type,
// Parameter -> PyParameterObject_Type).
template<typename T>
struct AllocType {};

// `value` is a pointer-to-pointer because the PyTypeObject* globals above are
// populated lazily at module-init time; dereference at use, not here.
#define DEFINE_ALLOC_TYPE(type) \
  template<> \
  struct AllocType<type> { \
    static PyTypeObject** value; \
  }; \
  PyTypeObject** AllocType<type>::value = &Py##type##Object_Type
DEFINE_ALLOC_TYPE(Tensor);
DEFINE_ALLOC_TYPE(Parameter);
#undef DEFINE_ALLOC_TYPE
// Wraps the C++ tensor `data` in a Python object, keeping the 1:1 pairing
// between a Tensor and its PyObject.
//
// Ownership model (mirrors PyTorch's tensor resurrection scheme): exactly one
// side "owns" the pair at a time. When Python is alive, the PyTensorObject
// holds the shared_ptr (owns_pyobj == false); when only C++ references
// remain, the Tensor holds a strong ref to its PyObject (owns_pyobj == true).
//
// `bind_pyobj` is non-null only when called from tp_init on an existing,
// caller-supplied Python object; otherwise a fresh wrapper is allocated.
// Returns a new reference (or None when `data` is null).
template<typename T>
PyObject* PyTensor_wrap(const std::shared_ptr<T>& data, PyTensorObject* bind_pyobj) {
  if (!data) { Py_RETURN_NONE; }
  PyObject* py_tensor = (PyObject*)data->pyobject();
  if (bind_pyobj == nullptr && py_tensor) {
    // Has been wrapped by python before
    if (data->owns_pyobj()) {
      // PyTensor are not alive in python side, so we flip back the ownership to PyTensor
      data->set_owns_pyobj(false);
      ((PyTensorObject*)py_tensor)->data = data;
      // NOTE: Needn't incref here, because the reference count of py_tensor is already increased
      return py_tensor;
    } else {
      // PyTensor is alive, so we directly incref it and return it
      Py_XINCREF(py_tensor);
      return py_tensor;
    }
  } else {
    // Has not been wrapped by python before, so we create a new PyTensor and give it the ownership
    if (bind_pyobj == nullptr) {
      // Allocate with the type matching T (Tensor or Parameter); see AllocType.
      bind_pyobj = (PyTensorObject*)PyTensorObject_Type->tp_alloc(*AllocType<T>::value, 0);
    }
    bind_pyobj->data = data;
    if (py_tensor) {
      // If it has bind pyobj, reset the shared_ptr in origin PyTensorObject
      ((PyTensorObject*)py_tensor)->data.reset();
    }
    // The tensor records its wrapper; the custom deleter drops the strong
    // ref the tensor holds while owns_pyobj is true.
    bind_pyobj->data->set_pyobject_ptr(std::unique_ptr<void, void (*)(void*)>(
        bind_pyobj, [](void* ptr) { Py_DECREF((PyObject*)ptr); }));
    bind_pyobj->data->set_owns_pyobj(false);
    return (PyObject*)bind_pyobj;
  }
}
// Called from tp_dealloc. If the C++ tensor still has owners, "resurrect"
// the Python wrapper: flip ownership to the C++ side (owns_pyobj = true),
// which keeps a strong ref to py_tensor via set_pyobject_ptr's holder, and
// incref so deallocation is aborted. Returns true when resurrection
// happened (dealloc must bail out), false when the object may really die.
bool PyTensor_tryResurrect(PyObject* py_tensor) {
  auto* self = (PyTensorObject*)py_tensor;
  if (self->data) {
    // PyTensor holds the ownership, now we flip it back to C++ and resurrect python object
    // temporarily
    auto tensor = self->data;
    self->data.reset();
    tensor->set_owns_pyobj(true);
    Py_XINCREF(py_tensor);
    return true;
  }
  // Otherwise, PyTensor was already not alive in python side
  return false;
}
} // namespace
static int PyTensorObject_init(PyObject* self, PyObject* args, PyObject* kwargs) { static int PyTensorObject_init(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
auto* temp = functional::_legacy_tensor_ctor(NULL, args, kwargs); auto* temp = functional::_legacy_tensor_ctor(NULL, args, kwargs);
if (PyErr_Occurred()) { throw py::error_already_set(); } if (PyErr_Occurred()) { throw py::error_already_set(); }
auto* _self = (PyTensorObject*)self; PyTensor_wrap<Tensor>(PyTensor_Unpack(temp), (PyTensorObject*)self);
_self->data = PyTensor_Unpack(temp);
_self->data->set_pyobject(self);
// reset temp data to prevent clearing the pyobject
// when the temp is deallocated
((PyTensorObject*)temp)->data.reset();
Py_XDECREF(temp);
return 0; return 0;
END_HANDLE_ERRORS_RET(-1) END_HANDLE_ERRORS_RET(-1)
} }
static void PyTensorObject_dealloc(PyObject* self) { static void PyTensorObject_dealloc(PyObject* self) {
auto* _self = (PyTensorObject*)self; if (PyTensor_tryResurrect(self)) { return; }
// clear pyobject
if (_self->data) {
_self->data->set_pyobject(NULL);
_self->data.reset();
}
// clear __dict__ // clear __dict__
PyObject** dict_ptr = _PyObject_GetDictPtr(self); PyObject** dict_ptr = _PyObject_GetDictPtr(self);
if (dict_ptr) { Py_CLEAR(*dict_ptr); } if (dict_ptr) { Py_CLEAR(*dict_ptr); }
...@@ -96,9 +151,9 @@ static int PyParameterObject_init(PyObject* self, PyObject* args, PyObject* kwar ...@@ -96,9 +151,9 @@ static int PyParameterObject_init(PyObject* self, PyObject* args, PyObject* kwar
return -1; return -1;
} }
if (self) { if (self) {
auto* _self = (PyTensorObject*)self; PyTensor_wrap<Parameter>(
_self->data = ASSERT_PTR(Parameter::MakeTensor(PyTensor_Unpack(data), requires_grad)); ASSERT_PTR(Parameter::MakeTensor(PyTensor_Unpack(data), requires_grad)),
_self->data->set_pyobject(self); (PyTensorObject*)self);
} }
return 0; return 0;
END_HANDLE_ERRORS_RET(-1) END_HANDLE_ERRORS_RET(-1)
...@@ -186,6 +241,16 @@ static PyObject* PyTensorObject_is_pinned(PyObject* self, PyObject* unused) { ...@@ -186,6 +241,16 @@ static PyObject* PyTensorObject_is_pinned(PyObject* self, PyObject* unused) {
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
// Tensor.is_floating_point(): True iff the tensor's dtype is a floating-point
// type. Returns a new reference to a Python bool.
static PyObject* PyTensorObject_is_floating_point(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  const bool is_fp = PyTensor_Unpack(self)->dtype()->is_floating_point();
  return PyBool_FromLong(is_fp ? 1 : 0);
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) { static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
int requires_grad = 1; int requires_grad = 1;
...@@ -203,11 +268,7 @@ static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, P ...@@ -203,11 +268,7 @@ static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, P
static PyObject* PyTensorObject_retain_grad(PyObject* self, PyObject* unused) { static PyObject* PyTensorObject_retain_grad(PyObject* self, PyObject* unused) {
HANDLE_ERRORS HANDLE_ERRORS
const auto& t = PyTensor_Unpack(self); const auto& t = PyTensor_Unpack(self);
if (!t->requires_grad()) { CHECK_JUST(t->set_retain_grad(true));
return PyErr_Format(PyExc_RuntimeError,
"can't retain_grad on Tensor that has requires_grad=False");
}
ASSERT(t->set_retain_grad(true));
Py_RETURN_NONE; Py_RETURN_NONE;
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
...@@ -226,7 +287,48 @@ static PyObject* PyTensorObject_clone(PyObject* self, PyObject* unused) { ...@@ -226,7 +287,48 @@ static PyObject* PyTensorObject_clone(PyObject* self, PyObject* unused) {
static PyObject* PyTensorObject_zero_(PyObject* self, PyObject* unused) { static PyObject* PyTensorObject_zero_(PyObject* self, PyObject* unused) {
HANDLE_ERRORS HANDLE_ERRORS
ASSERT(EagerMirroredTensorZeros(PyTensor_Unpack(self))); ASSERT(EagerLocalTensorZeros(PyTensor_Unpack(self)));
Py_XINCREF(self);
return self;
END_HANDLE_ERRORS
}
// Rewrites an NdSbp so every broadcast axis becomes partial-sum, leaving all
// other SBP kinds untouched. Returns one Symbol<SbpParallel> per axis.
std::vector<Symbol<SbpParallel>> RawSbpBToP(Symbol<NdSbp> nd_sbp) {
  std::vector<Symbol<SbpParallel>> result;
  result.reserve(nd_sbp->sbp_parallel_size());
  for (const auto& axis_sbp : nd_sbp->sbp_parallel()) {
    SbpParallel converted = axis_sbp;
    if (converted.has_broadcast_parallel()) { converted.mutable_partial_sum_parallel(); }
    result.emplace_back(SymbolOf(converted));
  }
  return result;
}
static constexpr auto* SbpBToP = DECORATE(&RawSbpBToP, ThreadLocalCached);
// Tensor._zero_grad_(set_to_none=False): clears the accumulated gradient.
// With set_to_none=True the grad tensor is dropped entirely; otherwise it is
// zeroed in place. Returns `self` (new reference) for chaining.
static PyObject* PyTensorObject_zero_grad(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  int set_to_none = 0;
  static const char* keywords[2] = {"set_to_none", NULL};
  // "|p" parses an optional Python bool into the C int set_to_none.
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p:_zero_grad_", const_cast<char**>(keywords),
                                   &set_to_none)) {
    return NULL;
  }
  const auto& t = PyTensor_Unpack(self);
  const auto acc_grad = ASSERT_PTR(t->acc_grad());
  if (acc_grad) {
    if (set_to_none) {
      ASSERT(t->set_acc_grad(NULL));
    } else {
      ASSERT(EagerLocalTensorZeros(acc_grad));
      if (acc_grad->is_global() && acc_grad->is_eager()) {
        // For eager global grads, rebuild the zeroed grad with broadcast SBP
        // replaced by partial-sum (SbpBToP) so subsequent gradient
        // accumulation sums correctly across ranks.
        // NOTE(review): assumes the GlobalToLocal/LocalToGlobal round-trip
        // preserves the just-zeroed contents — confirm against
        // functional::LocalToGlobal semantics.
        const auto local_tensor = ASSERT_PTR(functional::GlobalToLocal(acc_grad, false));
        const auto p = ASSERT_PTR(functional::LocalToGlobal(
            local_tensor, ASSERT(acc_grad->parallel_desc()), SbpBToP(ASSERT(acc_grad->nd_sbp())),
            *acc_grad->shape(), acc_grad->dtype(), false, false));
        ASSERT(acc_grad->set_data(p));
      }
    }
  }
  Py_XINCREF(self);
  return self;
  END_HANDLE_ERRORS
}
...@@ -269,17 +371,35 @@ static PyObject* PyTensorObject_to_numpy(PyObject* self, PyObject* unused) { ...@@ -269,17 +371,35 @@ static PyObject* PyTensorObject_to_numpy(PyObject* self, PyObject* unused) {
DataType data_type = t->dtype()->data_type(); DataType data_type = t->dtype()->data_type();
switch (data_type) { switch (data_type) {
#define SWITCH_EAGER_TENSOR_TO_NUMPY(cpp_type, of_type) \ #define SWITCH_EAGER_TENSOR_TO_NUMPY(cpp_type, of_type) \
case of_type: return ASSERT(EagerMirroredTensorToNumpy<cpp_type>(self)); case of_type: return ASSERT(EagerLocalTensorToNumpy<cpp_type>(self));
OF_PP_FOR_EACH_TUPLE(SWITCH_EAGER_TENSOR_TO_NUMPY, POD_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(SWITCH_EAGER_TENSOR_TO_NUMPY, POD_DATA_TYPE_SEQ)
case DataType::kFloat16: return ASSERT(EagerMirroredTensorToNumpy<float16>(self)); case DataType::kFloat16: return ASSERT(EagerLocalTensorToNumpy<float16>(self));
default: { default: {
return PyErr_Format(PyExc_RuntimeError, "Invalid datatype"); return PyErr_Format(PyExc_RuntimeError,
("Invalid datatype " + DataType_Name(data_type)).data());
} }
} }
#undef SWITCH_EAGER_TENSOR_TO_NUMPY #undef SWITCH_EAGER_TENSOR_TO_NUMPY
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
// Tensor.item(): converts a scalar tensor to the corresponding Python scalar,
// dispatching on the runtime dtype. Unsupported dtypes raise RuntimeError.
static PyObject* PyTensorObject_item(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  DataType data_type = t->dtype()->data_type();
  switch (data_type) {
// Expands to one `case` per (cpp_type, of_type) pair in
// POD_AND_HALF_DATA_TYPE_SEQ (all POD types plus float16).
#define CASE_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \
  case of_type: return ASSERT(EagerLocalTensorItem<cpp_type>(t));
    OF_PP_FOR_EACH_TUPLE(CASE_SCALAR_TENSOR_TO_SCALAR, POD_AND_HALF_DATA_TYPE_SEQ);
    default: {
      return PyErr_Format(PyExc_RuntimeError,
                          ("Invalid datatype " + DataType_Name(data_type)).data());
    }
  }
#undef CASE_SCALAR_TENSOR_TO_SCALAR
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* kwargs) { static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
const auto& tensor = PyTensor_Unpack(self); const auto& tensor = PyTensor_Unpack(self);
...@@ -299,6 +419,10 @@ static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* k ...@@ -299,6 +419,10 @@ static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* k
PyTensorType_FromDTypeAndDeviceType(tensor->dtype(), ASSERT(tensor->device())->enum_type()); PyTensorType_FromDTypeAndDeviceType(tensor->dtype(), ASSERT(tensor->device())->enum_type());
return PyUnicode_FromString(((PyTensorType*)tensor_type)->name); return PyUnicode_FromString(((PyTensorType*)tensor_type)->name);
} }
if (PyTensorMetaClass_CheckExact(tensor_type)) {
Optional<std::string> device = "cpu";
return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, DType::Float(), /*copy=*/false)));
}
if (PyUnicode_Check(tensor_type)) { if (PyUnicode_Check(tensor_type)) {
tensor_type = PyTensorType_FromString(PyUnicode_AsUTF8(tensor_type)); tensor_type = PyTensorType_FromString(PyUnicode_AsUTF8(tensor_type));
} }
...@@ -319,41 +443,38 @@ static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* k ...@@ -319,41 +443,38 @@ static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* k
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
#define DEFINE_TENSOR_METHOD(T, type_proto) \ namespace {
static PyObject* PyTensorObject__copy_to_numpy_##T(PyObject* self, PyObject* array) { \ void CopyFromNumpyArray(ep::Stream* stream,
HANDLE_ERRORS \ const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
ASSERT(CopyBetweenMirroredTensorAndNumpy<T>(PyTensor_Unpack(self), array, \ const NumPyArrayPtr& array_ptr) {
BlobNumpyCopyUtil<T>::To, "const", \ SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(),
/*block_host_until_done=*/true)); \ eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(),
Py_RETURN_NONE; \ memory::MakeHostMemCase());
END_HANDLE_ERRORS \ }
} \
static PyObject* PyTensorObject__copy_from_numpy_##T(PyObject* self, PyObject* array) { \
HANDLE_ERRORS \
auto* copied = PyArray_NewCopy((PyArrayObject*)array, NPY_CORDER); \
ASSERT(CopyBetweenMirroredTensorAndNumpy<T>(PyTensor_Unpack(self), copied, \
BlobNumpyCopyUtil<T>::From, "mut", \
/*block_host_until_done=*/false)); \
Py_DECREF(copied); \
Py_RETURN_NONE; \
END_HANDLE_ERRORS \
}
OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
#undef DEFINE_TENSOR_METHOD
static PyObject* PyTensorObject__get_copy_mirrored_tensor_to_numpy_func_name(PyObject* self, void CopyToNumpyArray(ep::Stream* stream,
PyObject* unused) { const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
const NumPyArrayPtr& array_ptr) {
SyncAutoMemcpy(stream, array_ptr.data(), eager_blob_object->dptr(),
eager_blob_object->ByteSizeOfBlobBody(), memory::MakeHostMemCase(),
eager_blob_object->mem_case());
}
} // namespace
//
static PyObject* PyTensorObject__copy_to_numpy(PyObject* self, PyObject* array) {
HANDLE_ERRORS HANDLE_ERRORS
return functional::CastToPyObject( ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), array, CopyToNumpyArray, "const",
GetCopyMirroredTensorToNumpyFuncName(PyTensor_Unpack(self)->dtype()->data_type())); /*block_host_until_done=*/true));
Py_RETURN_NONE;
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
static PyObject* PyTensorObject__copy_from_numpy(PyObject* self, PyObject* array) {
static PyObject* PyTensorObject__get_copy_mirrored_tensor_from_numpy_func_name(PyObject* self,
PyObject* unused) {
HANDLE_ERRORS HANDLE_ERRORS
return functional::CastToPyObject( auto* copied = PyArray_NewCopy((PyArrayObject*)array, NPY_CORDER);
GetCopyMirroredTensorFromNumpyFuncName(PyTensor_Unpack(self)->dtype()->data_type())); ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), copied, CopyFromNumpyArray, "mut",
/*block_host_until_done=*/false));
Py_DECREF(copied);
Py_RETURN_NONE;
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
...@@ -388,28 +509,24 @@ static PyMethodDef PyTensorObject_methods[] = { ...@@ -388,28 +509,24 @@ static PyMethodDef PyTensorObject_methods[] = {
{"contiguous_", PyTensorObject_contiguous_, METH_NOARGS, NULL}, {"contiguous_", PyTensorObject_contiguous_, METH_NOARGS, NULL},
{"pin_memory", PyTensorObject_pin_memory, METH_NOARGS, NULL}, {"pin_memory", PyTensorObject_pin_memory, METH_NOARGS, NULL},
{"is_pinned", PyTensorObject_is_pinned, METH_NOARGS, NULL}, {"is_pinned", PyTensorObject_is_pinned, METH_NOARGS, NULL},
{"is_floating_point", PyTensorObject_is_floating_point, METH_NOARGS, NULL},
{"requires_grad_", (PyCFunction)PyTensorObject_requires_grad_, METH_VARARGS | METH_KEYWORDS, {"requires_grad_", (PyCFunction)PyTensorObject_requires_grad_, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"retain_grad", PyTensorObject_retain_grad, METH_NOARGS, NULL}, {"retain_grad", PyTensorObject_retain_grad, METH_NOARGS, NULL},
{"detach", PyTensorObject_detach, METH_NOARGS, NULL}, {"detach", PyTensorObject_detach, METH_NOARGS, NULL},
{"clone", PyTensorObject_clone, METH_NOARGS, NULL}, {"clone", PyTensorObject_clone, METH_NOARGS, NULL},
{"zero_", PyTensorObject_zero_, METH_NOARGS, NULL}, {"zero_", PyTensorObject_zero_, METH_NOARGS, NULL},
{"_zero_grad_", (PyCFunction)PyTensorObject_zero_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"register_hook", PyTensorObject_register_hook, METH_O, NULL}, {"register_hook", PyTensorObject_register_hook, METH_O, NULL},
{"_register_post_grad_accumulation_hook", PyTensorObject__register_post_grad_accumulation_hook, {"_register_post_grad_accumulation_hook", PyTensorObject__register_post_grad_accumulation_hook,
METH_O, NULL}, METH_O, NULL},
{"global_id", PyTensorObject_global_id, METH_NOARGS, NULL}, {"global_id", PyTensorObject_global_id, METH_NOARGS, NULL},
{"check_meta_consistency", PyTensorObject_check_meta_consistency, METH_NOARGS, NULL}, {"check_meta_consistency", PyTensorObject_check_meta_consistency, METH_NOARGS, NULL},
{"to_numpy", PyTensorObject_to_numpy, METH_NOARGS, NULL}, {"to_numpy", PyTensorObject_to_numpy, METH_NOARGS, NULL},
{"item", PyTensorObject_item, METH_NOARGS, NULL},
{"type", (PyCFunction)PyTensorObject_type, METH_VARARGS | METH_KEYWORDS, NULL}, {"type", (PyCFunction)PyTensorObject_type, METH_VARARGS | METH_KEYWORDS, NULL},
#define DEFINE_TENSOR_METHOD(T, type_proto) \ {"_copy_to_numpy", PyTensorObject__copy_to_numpy, METH_O, NULL},
{"_copy_to_numpy_" #T, PyTensorObject__copy_to_numpy_##T, METH_O, NULL}, \ {"_copy_from_numpy", PyTensorObject__copy_from_numpy, METH_O, NULL},
{"_copy_from_numpy_" #T, PyTensorObject__copy_from_numpy_##T, METH_O, NULL},
OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
#undef DEFINE_TENSOR_METHOD
{"_get_copy_mirrored_tensor_to_numpy_func_name",
PyTensorObject__get_copy_mirrored_tensor_to_numpy_func_name, METH_NOARGS, NULL},
{"_get_copy_mirrored_tensor_from_numpy_func_name",
PyTensorObject__get_copy_mirrored_tensor_from_numpy_func_name, METH_NOARGS, NULL},
{"_register_storage_delete_hook", PyTensorObject__register_storage_delete_hook, METH_O, NULL}, {"_register_storage_delete_hook", PyTensorObject__register_storage_delete_hook, METH_O, NULL},
{NULL}}; {NULL}};
...@@ -451,16 +568,6 @@ static int PyTensorObject_set_grad(PyObject* self, PyObject* grad, void* unused) ...@@ -451,16 +568,6 @@ static int PyTensorObject_set_grad(PyObject* self, PyObject* grad, void* unused)
END_HANDLE_ERRORS_RET(-1) END_HANDLE_ERRORS_RET(-1)
} }
// Getter for the private `_is_grad_acc_inplace` flag on the autograd meta.
static PyObject* PyTensorObject__is_grad_acc_inplace(PyObject* self, void* unused) {
  const bool inplace = PyTensor_Unpack(self)->autograd_meta()->is_grad_acc_inplace();
  return functional::CastToPyObject(inplace);
}
// Setter for the private `_is_grad_acc_inplace` flag on the autograd meta.
// Returns 0 on success, -1 (with a Python exception set) on failure.
static int PyTensorObject_set__is_grad_acc_inplace(PyObject* self, PyObject* is_inplace,
                                                   void* unused) {
  // BUG FIX: the previous code passed the PyObject* pointer straight into a
  // bool parameter, so any non-null object — including Py_False — enabled the
  // flag. Evaluate the object's truthiness explicitly instead.
  const int is_true = PyObject_IsTrue(is_inplace);
  if (is_true < 0) { return -1; }  // propagate conversion error to Python
  PyTensor_Unpack(self)->mut_autograd_meta()->set_is_grad_acc_inplace(is_true != 0);
  return 0;
}
static PyObject* PyTensorObject_data(PyObject* self, void* unused) { static PyObject* PyTensorObject_data(PyObject* self, void* unused) {
HANDLE_ERRORS HANDLE_ERRORS
return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->data())); return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->data()));
...@@ -509,7 +616,7 @@ static PyObject* PyTensorObject_is_eager(PyObject* self, void* unused) { ...@@ -509,7 +616,7 @@ static PyObject* PyTensorObject_is_eager(PyObject* self, void* unused) {
} }
static PyObject* PyTensorObject_is_global(PyObject* self, void* unused) { static PyObject* PyTensorObject_is_global(PyObject* self, void* unused) {
return functional::CastToPyObject(PyTensor_Unpack(self)->is_consistent()); return functional::CastToPyObject(PyTensor_Unpack(self)->is_global());
} }
static PyObject* PyTensorObject_is_local(PyObject* self, void* unused) { static PyObject* PyTensorObject_is_local(PyObject* self, void* unused) {
...@@ -548,8 +655,6 @@ static PyGetSetDef PyTensorObject_properties[] = { ...@@ -548,8 +655,6 @@ static PyGetSetDef PyTensorObject_properties[] = {
{PYGETSET_NAME("is_cuda"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL}, {PYGETSET_NAME("is_cuda"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL},
{PYGETSET_NAME("grad"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL, {PYGETSET_NAME("grad"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL,
NULL}, NULL},
{PYGETSET_NAME("_is_grad_acc_inplace"), (getter)PyTensorObject__is_grad_acc_inplace,
(setter)PyTensorObject_set__is_grad_acc_inplace, NULL, NULL},
{PYGETSET_NAME("data"), (getter)PyTensorObject_data, (setter)PyTensorObject_set_data, NULL, {PYGETSET_NAME("data"), (getter)PyTensorObject_data, (setter)PyTensorObject_set_data, NULL,
NULL}, NULL},
{PYGETSET_NAME("grad_fn"), (getter)PyTensorObject_grad_fn, NULL, NULL, NULL}, {PYGETSET_NAME("grad_fn"), (getter)PyTensorObject_grad_fn, NULL, NULL, NULL},
...@@ -657,35 +762,17 @@ static PyTypeObject* MakeParameterType() { ...@@ -657,35 +762,17 @@ static PyTypeObject* MakeParameterType() {
} }
PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data) { PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data) {
if (!data) { Py_RETURN_NONE; } return PyTensor_wrap<Tensor>(data, /*bind_pyobj=*/nullptr);
if (data->pyobject()) { return PY_XINCREF((PyObject*)(data->pyobject())); }
auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyTensorObject_Type, 0);
if (self) {
self->data = data;
self->data->set_pyobject(self);
}
return (PyObject*)self;
} }
PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data) { PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data) {
if (!data) { Py_RETURN_NONE; } return PyTensor_wrap<Parameter>(data, /*bind_pyobj=*/nullptr);
if (data->pyobject()) { return PY_XINCREF((PyObject*)(data->pyobject())); }
auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyParameterObject_Type, 0);
if (self) {
self->data = data;
self->data->set_pyobject(self);
}
return (PyObject*)self;
} }
PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad) { PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad) {
if (!data) { Py_RETURN_NONE; } if (!data) { Py_RETURN_NONE; }
auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyParameterObject_Type, 0); return PyTensor_wrap<Parameter>(ASSERT_PTR(Parameter::MakeTensor(data, requires_grad)),
if (self) { /*bind_pyobj=*/nullptr);
self->data = ASSERT_PTR(Parameter::MakeTensor(data, requires_grad));
self->data->set_pyobject(self);
}
return (PyObject*)self;
} }
} // namespace one } // namespace one
......
...@@ -31,6 +31,10 @@ typedef struct { ...@@ -31,6 +31,10 @@ typedef struct {
extern PyTypeObject* PyTensorObject_Type; extern PyTypeObject* PyTensorObject_Type;
extern PyTypeObject* PyParameterObject_Type; extern PyTypeObject* PyParameterObject_Type;
// True iff `obj` is the flow.Tensor type object itself (i.e. the caller passed
// the class, not an instance and not a subclass).
inline bool PyTensorMetaClass_CheckExact(PyObject* obj) {
  return reinterpret_cast<PyObject*>(PyTensorObject_Type) == obj;
}
inline bool PyTensor_Check(PyObject* op) { return PyObject_TypeCheck(op, PyTensorObject_Type); } inline bool PyTensor_Check(PyObject* op) { return PyObject_TypeCheck(op, PyTensorObject_Type); }
inline bool PyTensor_CheckExact(PyObject* op) { inline bool PyTensor_CheckExact(PyObject* op) {
......
...@@ -23,6 +23,10 @@ limitations under the License. ...@@ -23,6 +23,10 @@ limitations under the License.
#include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/shape.h" #include "oneflow/core/common/shape.h"
#include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/common/wrap_dim_utils.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/api/python/functional/tensor_api.yaml.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/api/python/utils/tensor_utils.h"
namespace oneflow { namespace oneflow {
namespace one { namespace one {
...@@ -31,14 +35,35 @@ namespace one { ...@@ -31,14 +35,35 @@ namespace one {
#define ASSERT_PTR(x) (x).GetPtrOrThrow() #define ASSERT_PTR(x) (x).GetPtrOrThrow()
using functional::PyObjectPtr; using functional::PyObjectPtr;
namespace {
static PyObject* concat_self(PyObject* self, PyObject* args) { PyObject* concat_self(PyObject* self, PyObject* args) {
PyObjectPtr self_tuple(PyTuple_Pack(1, self)); PyObjectPtr self_tuple(PyTuple_Pack(1, self));
PyObject* tuple = PySequence_Concat(self_tuple.get(), args); PyObject* tuple = PySequence_Concat(self_tuple.get(), args);
CHECK_OR_THROW(tuple != NULL); CHECK_OR_THROW(tuple != NULL);
return tuple; return tuple;
} }
PyObject* ndarray_judgment_and_compatibility(PyObject* self, PyObject* other) {
if (PyArray_Check(other)) {
const auto& tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(!tensor->is_cuda())
<< Error::RuntimeError() << "Can't convert cuda device type tensor to numpy";
if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
other = functional::CastToPyObject(MakeGlobalTensorFromData(other, tensor->dtype(), placement,
sbp, /*requires_grad=*/false));
} else {
other = functional::CastToPyObject(functional::LocalTensorSharedNumpyData(other));
}
}
return other;
}
} // namespace
#define NB_UNARY_FUNC(func_name, bind_func) \ #define NB_UNARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* self) { \ static PyObject* func_name(PyObject* self) { \
HANDLE_ERRORS \ HANDLE_ERRORS \
...@@ -52,12 +77,13 @@ static PyObject* concat_self(PyObject* self, PyObject* args) { ...@@ -52,12 +77,13 @@ static PyObject* concat_self(PyObject* self, PyObject* args) {
#define NB_BINARY_FUNC(func_name, bind_func) \ #define NB_BINARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* a, PyObject* b) { \ static PyObject* func_name(PyObject* a, PyObject* b) { \
HANDLE_ERRORS \ HANDLE_ERRORS \
b = ndarray_judgment_and_compatibility(a, b); \
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \ PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \
auto* result = bind_func(NULL, tuple.get(), NULL); \ auto* result = bind_func(NULL, tuple.get(), NULL); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \ if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \ return result; \
END_HANDLE_ERRORS \ END_HANDLE_ERRORS \
} } // namespace one
NB_UNARY_FUNC(PyTensorObject_nb_absolute, functional::abs); NB_UNARY_FUNC(PyTensorObject_nb_absolute, functional::abs);
NB_UNARY_FUNC(PyTensorObject_nb_negative, functional::negative); NB_UNARY_FUNC(PyTensorObject_nb_negative, functional::negative);
...@@ -76,8 +102,9 @@ NB_BINARY_FUNC(PyTensorObject_nb_floor_div, functional::floor_divide); ...@@ -76,8 +102,9 @@ NB_BINARY_FUNC(PyTensorObject_nb_floor_div, functional::floor_divide);
NB_BINARY_FUNC(PyTensorObject_nb_true_div, functional::div); NB_BINARY_FUNC(PyTensorObject_nb_true_div, functional::div);
NB_BINARY_FUNC(PyTensorObject_nb_matrix_multiply, functional::matmul); NB_BINARY_FUNC(PyTensorObject_nb_matrix_multiply, functional::matmul);
static PyObject* PyTensorObject_nb_pow(PyObject* a, PyObject* b, PyObject* unsed) { static PyObject* PyTensorObject_nb_pow(PyObject* a, PyObject* b, PyObject* unused) {
HANDLE_ERRORS HANDLE_ERRORS
b = ndarray_judgment_and_compatibility(a, b);
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); PyObjectPtr tuple(PyTuple_Pack(2, a, b));
PyObject* result = functional::pow(NULL, tuple.get(), NULL); PyObject* result = functional::pow(NULL, tuple.get(), NULL);
if (PyErr_Occurred()) { throw py::error_already_set(); } if (PyErr_Occurred()) { throw py::error_already_set(); }
...@@ -99,6 +126,7 @@ static PyObject* PyTensorObject_nb_invert(PyObject* self) { ...@@ -99,6 +126,7 @@ static PyObject* PyTensorObject_nb_invert(PyObject* self) {
#define NB_INPLACE_BINARY_FUNC(func_name, bind_func) \ #define NB_INPLACE_BINARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* a, PyObject* b) { \ static PyObject* func_name(PyObject* a, PyObject* b) { \
HANDLE_ERRORS \ HANDLE_ERRORS \
b = ndarray_judgment_and_compatibility(a, b); \
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \ PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \
PyObjectPtr dict(PyDict_New()); \ PyObjectPtr dict(PyDict_New()); \
CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1); \ CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1); \
...@@ -115,7 +143,7 @@ NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_sub, functional::sub); ...@@ -115,7 +143,7 @@ NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_sub, functional::sub);
NB_BINARY_FUNC(PyTensorObject_nb_inplace_mul, functional::mul_); NB_BINARY_FUNC(PyTensorObject_nb_inplace_mul, functional::mul_);
NB_BINARY_FUNC(PyTensorObject_nb_inplace_true_div, functional::div_); NB_BINARY_FUNC(PyTensorObject_nb_inplace_true_div, functional::div_);
PyObject* PyTensorObject_nb_inplace_pow(PyObject* a, PyObject* b, PyObject* unsed) { PyObject* PyTensorObject_nb_inplace_pow(PyObject* a, PyObject* b, PyObject* unused) {
HANDLE_ERRORS HANDLE_ERRORS
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); PyObjectPtr tuple(PyTuple_Pack(2, a, b));
PyObjectPtr dict(PyDict_New()); PyObjectPtr dict(PyDict_New());
...@@ -193,6 +221,7 @@ UNARY_METHOD(PyTensorObject_selu, functional::Selu); ...@@ -193,6 +221,7 @@ UNARY_METHOD(PyTensorObject_selu, functional::Selu);
UNARY_METHOD(PyTensorObject_softsign, functional::SoftSign); UNARY_METHOD(PyTensorObject_softsign, functional::SoftSign);
UNARY_METHOD(PyTensorObject_log1p, functional::Log1p); UNARY_METHOD(PyTensorObject_log1p, functional::Log1p);
UNARY_METHOD(PyTensorObject_log2, functional::Log2); UNARY_METHOD(PyTensorObject_log2, functional::Log2);
UNARY_METHOD(PyTensorObject_log10, functional::Log10);
UNARY_METHOD(PyTensorObject_reciprocal, functional::Reciprocal); UNARY_METHOD(PyTensorObject_reciprocal, functional::Reciprocal);
UNARY_METHOD(PyTensorObject_ceil, functional::Ceil); UNARY_METHOD(PyTensorObject_ceil, functional::Ceil);
UNARY_METHOD(PyTensorObject_erf, functional::Erf); UNARY_METHOD(PyTensorObject_erf, functional::Erf);
...@@ -246,6 +275,7 @@ DIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod) ...@@ -246,6 +275,7 @@ DIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod)
DIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and) DIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and)
DIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or) DIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or)
DIRECT_PASS_FUNC(PyTensorObject_logical_xor, functional::logical_xor) DIRECT_PASS_FUNC(PyTensorObject_logical_xor, functional::logical_xor)
DIRECT_PASS_FUNC(PyTensorObject_equal, functional::equal)
DIRECT_PASS_FUNC(PyTensorObject_ne, functional::not_equal) DIRECT_PASS_FUNC(PyTensorObject_ne, functional::not_equal)
DIRECT_PASS_FUNC(PyTensorObject_lt, functional::less) DIRECT_PASS_FUNC(PyTensorObject_lt, functional::less)
DIRECT_PASS_FUNC(PyTensorObject_le, functional::less_equal) DIRECT_PASS_FUNC(PyTensorObject_le, functional::less_equal)
...@@ -256,17 +286,26 @@ DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin) ...@@ -256,17 +286,26 @@ DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin)
DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax) DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax)
DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul) DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul)
DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_) DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_)
DIRECT_PASS_FUNC(PyTensorObject_addcdiv, functional::addcdiv)
DIRECT_PASS_FUNC(PyTensorObject_addcdiv_, functional::addcdiv_)
DIRECT_PASS_FUNC(PyTensorObject_flip, functional::flip)
DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip) DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip)
DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_) DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_)
DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp) DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp)
DIRECT_PASS_FUNC(PyTensorObject_clamp_min, functional::clamp_min)
DIRECT_PASS_FUNC(PyTensorObject_clamp_max, functional::clamp_max)
DIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_) DIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_)
DIRECT_PASS_FUNC(PyTensorObject_clamp_min_, functional::clamp_min_)
DIRECT_PASS_FUNC(PyTensorObject_clamp_max_, functional::clamp_max_)
DIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten) DIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten)
DIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k) DIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k)
DIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select) DIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select)
DIRECT_PASS_FUNC(PyTensorObject_logsumexp, functional::logsumexp)
DIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum) DIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum)
DIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum) DIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum)
DIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril) DIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril)
DIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu) DIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu)
DIRECT_PASS_FUNC(PyTensorObject_triu_, functional::triu_)
DIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax) DIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax)
DIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax) DIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax)
DIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll) DIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll)
...@@ -281,9 +320,19 @@ DIRECT_PASS_FUNC(PyTensorObject_min, functional::min) ...@@ -281,9 +320,19 @@ DIRECT_PASS_FUNC(PyTensorObject_min, functional::min)
DIRECT_PASS_FUNC(PyTensorObject_median, functional::median) DIRECT_PASS_FUNC(PyTensorObject_median, functional::median)
DIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow) DIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow)
DIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk) DIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk)
DIRECT_PASS_FUNC(PyTensorObject_split, functional::split)
DIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow) DIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow)
DIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill) DIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill)
DIRECT_PASS_FUNC(PyTensorObject_masked_fill_, functional::masked_fill_)
DIRECT_PASS_FUNC(PyTensorObject_dot, functional::dot) DIRECT_PASS_FUNC(PyTensorObject_dot, functional::dot)
DIRECT_PASS_FUNC(PyTensorObject_nansum, functional::reduce_nansum)
DIRECT_PASS_FUNC(PyTensorObject_bernoulli, functional::bernoulli)
DIRECT_PASS_FUNC(PyTensorObject_bernoulli_, functional::bernoulli_)
DIRECT_PASS_FUNC(PyTensorObject_bincount, functional::bincount)
DIRECT_PASS_FUNC(PyTensorObject_isclose, functional::isclose)
DIRECT_PASS_FUNC(PyTensorObject_broadcast_to, functional::broadcast_to)
DIRECT_PASS_FUNC(PyTensorObject_unique, functional::unique)
DIRECT_PASS_FUNC(PyTensorObject_topk, functional::topk)
// functions that parsing at Python C api layer // functions that parsing at Python C api layer
static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) { static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) {
...@@ -569,11 +618,13 @@ REDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMean ...@@ -569,11 +618,13 @@ REDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMean
END_HANDLE_ERRORS \ END_HANDLE_ERRORS \
} }
DATATYPE_FUNC(PyTensorObject_bool, DType::Bool());
DATATYPE_FUNC(PyTensorObject_int, DType::Int32()); DATATYPE_FUNC(PyTensorObject_int, DType::Int32());
DATATYPE_FUNC(PyTensorObject_long, DType::Int64()); DATATYPE_FUNC(PyTensorObject_long, DType::Int64());
DATATYPE_FUNC(PyTensorObject_half, DType::Float16()); DATATYPE_FUNC(PyTensorObject_half, DType::Float16());
DATATYPE_FUNC(PyTensorObject_float, DType::Float()); DATATYPE_FUNC(PyTensorObject_float, DType::Float());
DATATYPE_FUNC(PyTensorObject_double, DType::Double()); DATATYPE_FUNC(PyTensorObject_double, DType::Double());
DATATYPE_FUNC(PyTensorObject_bfloat16, DType::BFloat16());
static PyObject* PyTensorObject_view(PyObject* self, PyObject* args, PyObject* kwargs) { static PyObject* PyTensorObject_view(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
...@@ -639,16 +690,20 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, ...@@ -639,16 +690,20 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args,
CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << "input must be a local tensor"; CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << "input must be a local tensor";
PyObject* placement_obj = Py_None; PyObject* placement_obj = Py_None;
PyObject* sbp_obj = Py_None; PyObject* sbp_obj = Py_None;
bool check_meta = true; PyObject* check_meta_obj = Py_True;
static const char* keywords[4] = {"placement", "sbp", "check_meta", NULL}; PyObject* copy_obj = Py_False;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!:local_to_global", static const char* keywords[5] = {"placement", "sbp", "check_meta", "copy", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!O!:local_to_global",
const_cast<char**>(keywords), &placement_obj, &sbp_obj, const_cast<char**>(keywords), &placement_obj, &sbp_obj,
&PyBool_Type, &check_meta)) { &PyBool_Type, &check_meta_obj, &PyBool_Type, &copy_obj)) {
return NULL; return NULL;
}; }
const bool check_meta = (check_meta_obj == Py_True);
const bool copy = (copy_obj == Py_True);
CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None) << Error::InvalidValueError( CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None)
"Converting a local tensor to global tensor must have placement and sbp parameters."); << Error::InvalidValueError()
<< "Converting a local tensor to global tensor must have placement and sbp parameters.";
CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj)) CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj))
<< Error::TypeError() << "Invalid parameter placement with type " << Error::TypeError() << "Invalid parameter placement with type "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj))); << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));
...@@ -662,29 +717,31 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, ...@@ -662,29 +717,31 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args,
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj))); << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj)));
sbp = functional::PyUnpackSbpParallelSequence(sbp_obj); sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
} }
return PyTensor_New(ASSERT_PTR(functional::ToConsistent( return PyTensor_New(ASSERT_PTR(functional::ToGlobal(
tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta))); tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta, copy)));
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, PyObject* kwargs) { static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self); auto tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(tensor->is_consistent()) CHECK_OR_THROW(tensor->is_global()) << Error::RuntimeError() << "input must be a global tensor";
<< Error::RuntimeError() << "input must be a global tensor";
PyObject* placement_obj = Py_None; PyObject* placement_obj = Py_None;
PyObject* sbp_obj = Py_None; PyObject* sbp_obj = Py_None;
PyObject* grad_sbp_obj = Py_None; PyObject* grad_sbp_obj = Py_None;
Symbol<ParallelDesc> placement; Symbol<ParallelDesc> placement;
std::vector<Symbol<SbpParallel>> sbp; std::vector<Symbol<SbpParallel>> sbp;
std::vector<Symbol<SbpParallel>> grad_sbp; std::vector<Symbol<SbpParallel>> grad_sbp;
bool check_meta = false; PyObject* check_meta_obj = Py_False;
static const char* keywords[5] = {"placement", "sbp", "grad_sbp", "check_meta", NULL}; PyObject* copy_obj = Py_False;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!:global_to_global", static const char* keywords[6] = {"placement", "sbp", "grad_sbp", "check_meta", "copy", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!O!:global_to_global",
const_cast<char**>(keywords), &placement_obj, &sbp_obj, const_cast<char**>(keywords), &placement_obj, &sbp_obj,
&grad_sbp_obj, &PyBool_Type, &check_meta)) { &grad_sbp_obj, &PyBool_Type, &check_meta_obj, &copy_obj)) {
return NULL; return NULL;
}; }
const bool check_meta = (check_meta_obj == Py_True);
const bool copy = (copy_obj == Py_True);
// sbp // sbp
CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj) CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj)
...@@ -721,7 +778,7 @@ static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, ...@@ -721,7 +778,7 @@ static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args,
grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj); grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj);
} }
return PyTensor_New( return PyTensor_New(
ASSERT_PTR(functional::ToConsistent(tensor, placement, sbp, grad_sbp, check_meta))); ASSERT_PTR(functional::ToGlobal(tensor, placement, sbp, grad_sbp, check_meta, copy)));
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
...@@ -729,7 +786,7 @@ static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObje ...@@ -729,7 +786,7 @@ static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObje
HANDLE_ERRORS HANDLE_ERRORS
const auto& tensor = PyTensor_Unpack(self); const auto& tensor = PyTensor_Unpack(self);
PyObject* result = NULL; PyObject* result = NULL;
if (tensor->is_consistent()) if (tensor->is_global())
result = PyTensorObject_global_to_global(self, args, kwargs); result = PyTensorObject_global_to_global(self, args, kwargs);
else { else {
result = PyTensorObject_local_to_global(self, args, kwargs); result = PyTensorObject_local_to_global(self, args, kwargs);
...@@ -740,57 +797,116 @@ static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObje ...@@ -740,57 +797,116 @@ static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObje
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused) { static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self); auto tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(tensor->is_consistent()) CHECK_OR_THROW(tensor->is_global())
<< Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!"; << Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!";
return PyTensor_New(ASSERT_PTR(functional::ConsistentToLocal(tensor))); bool copy = false;
static const char* keywords[2] = {"copy", NULL};
if (!PyArg_ParseTupleAndKeywords(unused, kwargs, "|$O!:to_local", const_cast<char**>(keywords),
&PyBool_Type, &copy)) {
return NULL;
};
return PyTensor_New(ASSERT_PTR(functional::GlobalToLocal(tensor, /*copy=*/copy)));
END_HANDLE_ERRORS END_HANDLE_ERRORS
} }
int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) { static PyObject* PyTensorObject_type_as(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self); auto self_tensor = PyTensor_Unpack(self);
PyObject* other = NULL;
static const char* keywords[2] = {"other", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|:type_as", const_cast<char**>(keywords),
&other)) {
return NULL;
}
// target is local
auto other_tensor = PyTensor_Unpack(other);
if (other_tensor->is_local()) {
Optional<Symbol<Device>> device = ASSERT(other_tensor->device());
if (self_tensor->is_global()) {
self_tensor = ASSERT_PTR(functional::GlobalToLocal(self_tensor, /*copy=*/false));
}
return PyTensor_New(
ASSERT_PTR(functional::To(self_tensor, device, other_tensor->dtype(), /*copy=*/false)));
}
// target is global
std::shared_ptr<Tensor> value_tensor; std::shared_ptr<Tensor> value_tensor;
value_tensor = ASSERT_PTR(functional::To(self_tensor, other_tensor->dtype(), /*copy=*/false));
Symbol<ParallelDesc> placement = ASSERT(other_tensor->parallel_desc());
std::vector<Symbol<SbpParallel>> sbp;
auto ndsbp = ASSERT(other_tensor->nd_sbp());
for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) {
sbp.emplace_back(ndsbp->sbp_parallel(i));
}
return PyTensor_New(
ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false)));
END_HANDLE_ERRORS
}
int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
HANDLE_ERRORS
CHECK_OR_THROW(functional::PyTensorIndexCheck(item)) CHECK_OR_THROW(functional::PyTensorIndexCheck(item))
<< Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not " << Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item))); << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));
CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value)) CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value))
<< Error::TypeError() << "tensor_setitem(): argument 'value' must be tensor or scalar, not " << Error::TypeError() << "tensor_setitem(): argument 'value' must be tensor or scalar, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value))); << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));
const auto& index_item = functional::PyUnpackTensorIndex(item);
if (tensor->is_consistent()) { auto tensor = PyTensor_Unpack(self);
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc()); // NOTE: use masked_fill_(local,global) to avoid D2H in TensorSetItem if index is bool tensor
auto ndsbp = ASSERT(tensor->nd_sbp()); if (functional::PyScalarCheck(value) && index_item.size() == 1 && index_item[0].IsTensor()) {
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(), const auto& index_tensor = index_item[0].tensor();
ASSERT(MakeBroadcastSbpParallel())); if (index_tensor->shape() == tensor->shape()
if (functional::PyScalarCheck(value)) { && (index_tensor->dtype() == DType::Bool() || index_tensor->dtype() == DType::UInt8())) {
Scalar value_scalar = functional::PyUnpackScalar(value); ASSERT_PTR(
value_tensor = ASSERT_PTR( functional::MaskedFillInplace(tensor, index_tensor, functional::PyUnpackScalar(value)));
functional::ConsistentConstant({1}, value_scalar, tensor->dtype(), placement, sbp)); return 0;
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_consistent())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor = ASSERT_PTR(functional::ToConsistent(value_tensor, placement, sbp, {}, true));
} }
} else { }
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value); std::shared_ptr<Tensor> value_tensor;
value_tensor = ASSERT_PTR( {
functional::Constant({1}, value_scalar, tensor->dtype(), ASSERT(tensor->device()))); if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_global())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor = ASSERT_PTR(
functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));
}
} else { } else {
value_tensor = PyTensor_Unpack(value); if (functional::PyScalarCheck(value)) {
CHECK_OR_THROW(value_tensor->is_local()) // NOTE: initialize value_tensor in eager mode
<< Error::RuntimeError() LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled=*/false);
<< "tensor_setitem(): value must be a local tensor when self is local"; Scalar value_scalar = functional::PyUnpackScalar(value);
Optional<Symbol<Device>> device = ASSERT(tensor->device()); value_tensor = ASSERT_PTR(functional::Constant(Shape({}), value_scalar, tensor->dtype(),
value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false)); ASSERT(tensor->device())));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor =
ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
}
} }
} }
ASSERT(functional::TensorSetItem(tensor, functional::PyUnpackTensorIndex(item), value_tensor)); ASSERT(functional::TensorSetItem(tensor, index_item, value_tensor));
return 0; return 0;
END_HANDLE_ERRORS_RET(-1) END_HANDLE_ERRORS_RET(-1)
} }
...@@ -812,18 +928,23 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -812,18 +928,23 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL}, {"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL}, {"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcdiv", (PyCFunction)PyTensorObject_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcdiv_", (PyCFunction)PyTensorObject_addcdiv_, METH_VARARGS | METH_KEYWORDS, NULL},
{"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL}, {"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
{"bool", PyTensorObject_bool, METH_NOARGS, NULL},
{"int", PyTensorObject_int, METH_NOARGS, NULL}, {"int", PyTensorObject_int, METH_NOARGS, NULL},
{"long", PyTensorObject_long, METH_NOARGS, NULL}, {"long", PyTensorObject_long, METH_NOARGS, NULL},
{"half", PyTensorObject_half, METH_NOARGS, NULL}, {"half", PyTensorObject_half, METH_NOARGS, NULL},
{"float", PyTensorObject_float, METH_NOARGS, NULL}, {"float", PyTensorObject_float, METH_NOARGS, NULL},
{"double", PyTensorObject_double, METH_NOARGS, NULL}, {"double", PyTensorObject_double, METH_NOARGS, NULL},
{"bfloat16", PyTensorObject_bfloat16, METH_NOARGS, NULL},
{"local_to_global", (PyCFunction)PyTensorObject_local_to_global, METH_VARARGS | METH_KEYWORDS, {"local_to_global", (PyCFunction)PyTensorObject_local_to_global, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"global_to_global", (PyCFunction)PyTensorObject_global_to_global, METH_VARARGS | METH_KEYWORDS, {"global_to_global", (PyCFunction)PyTensorObject_global_to_global, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"to_local", PyTensorObject_to_local, METH_NOARGS, NULL}, {"to_local", (PyCFunction)PyTensorObject_to_local, METH_VARARGS | METH_KEYWORDS, NULL},
{"to_global", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL}, {"to_global", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL},
{"type_as", (PyCFunction)PyTensorObject_type_as, METH_VARARGS | METH_KEYWORDS, NULL},
{"cpu", PyTensorObject_cpu, METH_NOARGS, NULL}, {"cpu", PyTensorObject_cpu, METH_NOARGS, NULL},
{"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL}, {"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
{"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL}, {"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL},
...@@ -839,6 +960,7 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -839,6 +960,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {
// macro DIRECT_PASS_FUNC // macro DIRECT_PASS_FUNC
{"floor_divide", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL}, {"floor_divide", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL},
{"atan2", (PyCFunction)PyTensorObject_atan2, METH_VARARGS | METH_KEYWORDS, NULL}, {"atan2", (PyCFunction)PyTensorObject_atan2, METH_VARARGS | METH_KEYWORDS, NULL},
{"equal", (PyCFunction)PyTensorObject_equal, METH_VARARGS | METH_KEYWORDS, NULL},
{"gt", (PyCFunction)PyTensorObject_gt, METH_VARARGS | METH_KEYWORDS, NULL}, {"gt", (PyCFunction)PyTensorObject_gt, METH_VARARGS | METH_KEYWORDS, NULL},
{"ge", (PyCFunction)PyTensorObject_ge, METH_VARARGS | METH_KEYWORDS, NULL}, {"ge", (PyCFunction)PyTensorObject_ge, METH_VARARGS | METH_KEYWORDS, NULL},
{"div", (PyCFunction)PyTensorObject_div, METH_VARARGS | METH_KEYWORDS, NULL}, {"div", (PyCFunction)PyTensorObject_div, METH_VARARGS | METH_KEYWORDS, NULL},
...@@ -853,10 +975,15 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -853,10 +975,15 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"ne", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL}, {"ne", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL},
{"lt", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL}, {"lt", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL},
{"le", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL}, {"le", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL},
{"flip", (PyCFunction)PyTensorObject_flip, METH_VARARGS | METH_KEYWORDS, NULL},
{"clip", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL}, {"clip", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL},
{"clip_", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL}, {"clip_", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp", (PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp", (PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp_min", (PyCFunction)PyTensorObject_clamp_min, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp_max", (PyCFunction)PyTensorObject_clamp_max, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp_", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL}, {"clamp_", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp_min_", (PyCFunction)PyTensorObject_clamp_min_, METH_VARARGS | METH_KEYWORDS, NULL},
{"clamp_max_", (PyCFunction)PyTensorObject_clamp_max_, METH_VARARGS | METH_KEYWORDS, NULL},
{"flatten", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL}, {"flatten", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL},
{"in_top_k", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL}, {"in_top_k", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL},
{"index_select", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL}, {"index_select", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL},
...@@ -864,6 +991,7 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -864,6 +991,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"minimum", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL}, {"minimum", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL},
{"tril", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL}, {"tril", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL},
{"triu", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL}, {"triu", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL},
{"triu_", (PyCFunction)PyTensorObject_triu_, METH_VARARGS | METH_KEYWORDS, NULL},
{"softmax", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL}, {"softmax", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
{"log_softmax", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL}, {"log_softmax", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
{"roll", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL}, {"roll", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL},
...@@ -879,9 +1007,19 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -879,9 +1007,19 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"median", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL}, {"median", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL},
{"pow", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL}, {"pow", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL},
{"chunk", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL}, {"chunk", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL},
{"split", (PyCFunction)PyTensorObject_split, METH_VARARGS | METH_KEYWORDS, NULL},
{"narrow", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL}, {"narrow", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL},
{"masked_fill", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL}, {"masked_fill", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL},
{"masked_fill_", (PyCFunction)PyTensorObject_masked_fill_, METH_VARARGS | METH_KEYWORDS, NULL},
{"dot", (PyCFunction)PyTensorObject_dot, METH_VARARGS | METH_KEYWORDS, NULL}, {"dot", (PyCFunction)PyTensorObject_dot, METH_VARARGS | METH_KEYWORDS, NULL},
{"nansum", (PyCFunction)PyTensorObject_nansum, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli", (PyCFunction)PyTensorObject_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli_", (PyCFunction)PyTensorObject_bernoulli_, METH_VARARGS | METH_KEYWORDS, NULL},
{"bincount", (PyCFunction)PyTensorObject_bincount, METH_VARARGS | METH_KEYWORDS, NULL},
{"isclose", (PyCFunction)PyTensorObject_isclose, METH_VARARGS | METH_KEYWORDS, NULL},
{"broadcast_to", (PyCFunction)PyTensorObject_broadcast_to, METH_VARARGS | METH_KEYWORDS, NULL},
{"unique", (PyCFunction)PyTensorObject_unique, METH_VARARGS | METH_KEYWORDS, NULL},
{"topk", (PyCFunction)PyTensorObject_topk, METH_VARARGS | METH_KEYWORDS, NULL},
// macro UNARY_METHOD // macro UNARY_METHOD
{"abs", PyTensorObject_abs, METH_NOARGS, NULL}, {"abs", PyTensorObject_abs, METH_NOARGS, NULL},
...@@ -908,6 +1046,7 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -908,6 +1046,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"softsign", PyTensorObject_softsign, METH_NOARGS, NULL}, {"softsign", PyTensorObject_softsign, METH_NOARGS, NULL},
{"log1p", PyTensorObject_log1p, METH_NOARGS, NULL}, {"log1p", PyTensorObject_log1p, METH_NOARGS, NULL},
{"log2", PyTensorObject_log2, METH_NOARGS, NULL}, {"log2", PyTensorObject_log2, METH_NOARGS, NULL},
{"log10", PyTensorObject_log10, METH_NOARGS, NULL},
{"reciprocal", PyTensorObject_reciprocal, METH_NOARGS, NULL}, {"reciprocal", PyTensorObject_reciprocal, METH_NOARGS, NULL},
{"asin", PyTensorObject_asin, METH_NOARGS, NULL}, {"asin", PyTensorObject_asin, METH_NOARGS, NULL},
{"arcsin", PyTensorObject_asin, METH_NOARGS, NULL}, {"arcsin", PyTensorObject_asin, METH_NOARGS, NULL},
...@@ -942,6 +1081,7 @@ PyMethodDef PyTensorObject_extra_methods[] = { ...@@ -942,6 +1081,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"view_as", (PyCFunction)PyTensorObject_view_as, METH_VARARGS | METH_KEYWORDS, NULL}, {"view_as", (PyCFunction)PyTensorObject_view_as, METH_VARARGS | METH_KEYWORDS, NULL},
{"permute", (PyCFunction)PyTensorObject_permute, METH_VARARGS | METH_KEYWORDS, NULL}, {"permute", (PyCFunction)PyTensorObject_permute, METH_VARARGS | METH_KEYWORDS, NULL},
{"transpose", (PyCFunction)PyTensorObject_transpose, METH_VARARGS | METH_KEYWORDS, NULL}, {"transpose", (PyCFunction)PyTensorObject_transpose, METH_VARARGS | METH_KEYWORDS, NULL},
{"logsumexp", (PyCFunction)PyTensorObject_logsumexp, METH_VARARGS | METH_KEYWORDS, NULL},
{NULL}, {NULL},
}; };
...@@ -954,7 +1094,7 @@ PyObject* PyTensorObject_richcompare(PyObject* self, PyObject* other, int op) { ...@@ -954,7 +1094,7 @@ PyObject* PyTensorObject_richcompare(PyObject* self, PyObject* other, int op) {
case Py_LE: return functional::less_equal(NULL, tuple.get(), NULL); case Py_LE: return functional::less_equal(NULL, tuple.get(), NULL);
case Py_EQ: { case Py_EQ: {
if (self == Py_None || other == Py_None) return Py_False; if (self == Py_None || other == Py_None) return Py_False;
return functional::equal(NULL, tuple.get(), NULL); return functional::broadcast_equal(NULL, tuple.get(), NULL);
} }
case Py_NE: return functional::not_equal(NULL, tuple.get(), NULL); case Py_NE: return functional::not_equal(NULL, tuple.get(), NULL);
case Py_GT: return functional::greater(NULL, tuple.get(), NULL); case Py_GT: return functional::greater(NULL, tuple.get(), NULL);
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/framework/thread.h"
#include "oneflow/core/common/env_var/vm.h"
namespace py = pybind11;
namespace oneflow {
namespace {
// Process-wide registry of worker-thread uids in use, guarded by a mutex.
// Capacity is the default uid plus ONEFLOW_VM_WORKER_THREAD_LIMIT workers.
class UsingThreadUidSet final {
 public:
  // Seeds the set with the default stream thread uid; the limit counts that
  // seed plus the environment-configured number of extra worker threads.
  UsingThreadUidSet()
      : using_thread_uids_({Stream::kDefaultStreamThreadUid}),
        thread_limits_(using_thread_uids_.size()
                       + ThreadLocalEnvInteger<ONEFLOW_VM_WORKER_THREAD_LIMIT>()) {}
  ~UsingThreadUidSet() = default;

  // Allocates the smallest unused uid, or fails when the limit is reached.
  Maybe<int64_t> Get() {
    std::unique_lock<std::mutex> lock(mutex_);
    CHECK_LT_OR_RETURN(using_thread_uids_.size(), thread_limits_)
        << "can not create more worker threads. please check your code or increase environment "
           "variable ONEFLOW_VM_WORKER_THREAD_LIMIT(default value:"
        << ThreadLocalEnvInteger<ONEFLOW_VM_WORKER_THREAD_LIMIT>() << ")";
    // Pigeonhole: among the uids [0, size()] at least one is free, so this
    // loop always returns. `size_t` fixes the signed/unsigned comparison the
    // original `int` index produced against set::size().
    for (size_t i = 0; i < using_thread_uids_.size() + 1; ++i) {
      if (using_thread_uids_.count(i) == 0) {
        using_thread_uids_.insert(i);
        return static_cast<int64_t>(i);
      }
    }
    UNIMPLEMENTED_THEN_RETURN();
  }

  // Returns a previously allocated uid to the pool. The default stream uid is
  // never handed out by Get() after construction, so putting it back is a bug.
  Maybe<void> Put(int64_t thread_uid) {
    std::unique_lock<std::mutex> lock(mutex_);
    CHECK_NE_OR_RETURN(thread_uid, Stream::kDefaultStreamThreadUid)
        << "default thread_uid should not be erased. value: " << thread_uid;
    CHECK_OR_RETURN(using_thread_uids_.erase(thread_uid) > 0)
        << "no thread_uid found. (current: " << thread_uid << ").";
    return Maybe<void>::Ok();
  }

 private:
  std::set<int64_t> using_thread_uids_;  // uids currently handed out
  size_t thread_limits_;                 // max simultaneous uids
  std::mutex mutex_;                     // guards using_thread_uids_
};
// Accessor for the process-wide uid registry (constructed on first use,
// never destroyed before static teardown).
UsingThreadUidSet* MutUsingThreadUidSet() {
  static UsingThreadUidSet registry;
  return &registry;
}
} // namespace
// Factory: reserves a worker-thread uid from the global pool and wraps it in
// an AsyncThread. Fails when ONEFLOW_VM_WORKER_THREAD_LIMIT is exhausted.
/*static*/ Maybe<AsyncThread> AsyncThread::New() {
  int64_t thread_uid = JUST(MutUsingThreadUidSet()->Get());
  return std::shared_ptr<AsyncThread>(new AsyncThread(thread_uid));
}
// Returns this thread's uid to the pool. NOTE(review): GetOrThrow() inside a
// destructor will std::terminate the process if Put() fails — confirm this
// hard-failure behavior is intended.
AsyncThread::~AsyncThread() { MutUsingThreadUidSet()->Put(thread_uid_).GetOrThrow(); }
} // namespace oneflow
// Exposes AsyncThread to Python. Constructing one from Python allocates a
// worker-thread uid and throws when the configured limit is exceeded.
ONEFLOW_API_PYBIND11_MODULE("", m) {
  using namespace oneflow;
  py::class_<AsyncThread, std::shared_ptr<AsyncThread>>(m, "AsyncThread").def(py::init([]() {
    return AsyncThread::New().GetPtrOrThrow();
  }));
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
// RAII owner of a unique worker-thread uid: New() reserves a uid from the
// process-wide pool and the destructor releases it.
class AsyncThread final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AsyncThread);
  ~AsyncThread();
  // Reserves a fresh uid; fails when the worker-thread limit is reached.
  static Maybe<AsyncThread> New();
  // The uid reserved for this object (stable for its lifetime).
  int64_t thread_uid() const { return thread_uid_; }

 private:
  // Private: instances are only created through New().
  AsyncThread(int64_t thread_uid) : thread_uid_(thread_uid) {}
  int64_t thread_uid_;
};
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <limits>
#include <type_traits>

#include "oneflow/api/python/exception/exception.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/framework/typeinfo.h"
namespace oneflow {
namespace one {
#define ASSERT(x) (x).GetOrThrow()
#if PY_VERSION_HEX < 0x03070000
#define PYGETSET_NAME(name) const_cast<char*>(name)
#else
#define PYGETSET_NAME(name) (name)
#endif
using functional::PyObjectPtr;
#define INT_TYPE INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ
// TODO(WangYi): support bf16
#define FLOAT_TYPE FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ
// Returns the maximum representable value of `datatype` as a Python int or
// float, or NULL (with no Python error set) when the data type is unsupported.
PyObject* PyGetMaxVal(DataType datatype) {
// Signed maxima go through PyLong_FromLongLong and unsigned ones through
// PyLong_FromUnsignedLongLong so 64-bit extremes are exact: the original
// PyLong_FromLong wraps uint64's max and truncates int64 on LLP64 platforms.
#define GET_INT_MAX_VAL(cpp_type, of_datatype)                                          \
  case of_datatype:                                                                     \
    return std::is_signed<DataTypeToType<of_datatype>>::value                           \
               ? PyLong_FromLongLong(GetMaxVal<DataTypeToType<of_datatype>>())          \
               : PyLong_FromUnsignedLongLong(GetMaxVal<DataTypeToType<of_datatype>>());
#define GET_FLOAT_MAX_VAL(cpp_type, of_datatype) \
  case of_datatype: return PyFloat_FromDouble(GetMaxVal<DataTypeToType<of_datatype>>());
  switch (datatype) {
    OF_PP_FOR_EACH_TUPLE(GET_INT_MAX_VAL, INT_TYPE);
    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_MAX_VAL, FLOAT_TYPE);
    default: return NULL;
#undef GET_INT_MAX_VAL
#undef GET_FLOAT_MAX_VAL
  }
}
// Returns the minimum representable value of `datatype` as a Python int or
// float, or NULL (with no Python error set) when the data type is unsupported.
PyObject* PyGetMinVal(DataType datatype) {
// Integer minima always fit in long long (unsigned minima are 0), so
// PyLong_FromLongLong is exact for every supported integer type, unlike the
// original PyLong_FromLong which truncates int64's minimum on LLP64 platforms.
#define GET_INT_MIN_VAL(cpp_type, of_datatype) \
  case of_datatype: \
    return PyLong_FromLongLong(GetMinVal<DataTypeToType<of_datatype>>());
#define GET_FLOAT_MIN_VAL(cpp_type, of_datatype) \
  case of_datatype: return PyFloat_FromDouble(GetMinVal<DataTypeToType<of_datatype>>());
  switch (datatype) {
    OF_PP_FOR_EACH_TUPLE(GET_INT_MIN_VAL, INT_TYPE);
    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_MIN_VAL, FLOAT_TYPE);
    default: return NULL;
#undef GET_INT_MIN_VAL
#undef GET_FLOAT_MIN_VAL
  }
}
// Switch-case generators for the finfo attributes over FLOAT_TYPE:
//   resolution = 10^-digits10 (approximate decimal resolution),
//   eps        = machine epsilon,
//   tiny       = smallest positive normalized value (numeric_limits::min).
#define GET_FLOAT_RESOLUTION(cpp_type, of_datatype) \
  case of_datatype: \
    return PyFloat_FromDouble( \
        std::pow(10, -std::numeric_limits<DataTypeToType<of_datatype>>::digits10));
#define GET_FLOAT_EPS(cpp_type, of_datatype) \
  case of_datatype: \
    return PyFloat_FromDouble(std::numeric_limits<DataTypeToType<of_datatype>>::epsilon());
#define GET_FLOAT_TINY(cpp_type, of_datatype) \
  case of_datatype: \
    return PyFloat_FromDouble(std::numeric_limits<DataTypeToType<of_datatype>>::min());
// Static type objects for oneflow.iinfo / oneflow.finfo. Only tp_name and
// tp_basicsize are set here; the remaining slots are filled by
// init_info_type() before PyType_Ready is called.
PyTypeObject PyIInfoType = {
    PyVarObject_HEAD_INIT(NULL, 0) "oneflow.iinfo",  // tp_name
    sizeof(PyDTypeInfo),                             // tp_basicsize
};
PyTypeObject PyFInfoType = {
    PyVarObject_HEAD_INIT(NULL, 0) "oneflow.finfo",  // tp_name
    sizeof(PyDTypeInfo),                             // tp_basicsize
};
// Implements oneflow.iinfo(type): validates that the argument is an integral
// oneflow.dtype and returns a new PyDTypeInfo holding it.
static PyObject* PyIInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dtype_obj = NULL;
  static const char* keywords[2] = {"type", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:iinfo", const_cast<char**>(keywords),
                                   &dtype_obj)) {
    return NULL;
  }
  CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj))
      << Error::TypeError() << "iinfo(): argument 'type' must be oneflow.dtype, but found "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj)));
  // Renamed from `self` to stop shadowing the PyTypeObject* parameter.
  auto* info = (PyDTypeInfo*)PyIInfoType.tp_alloc(&PyIInfoType, 0);
  if (!info) { throw py::error_already_set(); }
  info->dtype = functional::PyUnpackDType(dtype_obj);
  // Reject floating/complex dtypes; those belong to oneflow.finfo.
  CHECK_OR_THROW(!info->dtype->is_floating_point() && !info->dtype->is_complex())
      << Error::TypeError()
      << "oneflow.iinfo() requires an integer input type. Use oneflow.finfo to handle '"
      << info->dtype->name() << "' ";
  return (PyObject*)info;
  END_HANDLE_ERRORS
}
// Implements oneflow.finfo(type=oneflow.float): validates that the argument is
// a floating-point oneflow.dtype (bfloat16 not yet supported) and returns a
// new PyDTypeInfo holding it.
static PyObject* PyFInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  // Defaults to oneflow.float when no argument is given ("|O" format).
  PyObject* dtype_obj = functional::CastToPyObject(DType::Float());
  static const char* keywords[2] = {"type", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:finfo", const_cast<char**>(keywords),
                                   &dtype_obj)) {
    return NULL;
  }
  CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj))
      << Error::TypeError() << "finfo(): argument 'type' must be oneflow.dtype, but found "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj)));
  // Renamed from `self` to stop shadowing the PyTypeObject* parameter.
  auto* info = (PyDTypeInfo*)PyFInfoType.tp_alloc(&PyFInfoType, 0);
  if (!info) { throw py::error_already_set(); }
  info->dtype = functional::PyUnpackDType(dtype_obj);
  CHECK_OR_THROW(info->dtype->is_floating_point() && !info->dtype->is_complex())
      << Error::TypeError()
      << "oneflow.finfo() requires a float input type. Use oneflow.iinfo to handle '"
      << info->dtype->name() << "' ";
  // TODO (wangyi): support bfloat16
  CHECK_OR_THROW(info->dtype->data_type() != kBFloat16)
      << Error::TypeError() << "bfloat16 is not supported yet by oneflow.finfo";
  return (PyObject*)info;
  END_HANDLE_ERRORS
}
// Getter for iinfo/finfo `.bits`: the dtype's width in bits.
static PyObject* PyDInfo_bits(PyObject* self, void*) {
  HANDLE_ERRORS
  const auto& dtype = ((PyDTypeInfo*)self)->dtype;
  return PyLong_FromSize_t(ASSERT(dtype->bytes()) * 8);
  END_HANDLE_ERRORS
}
// Getter for iinfo/finfo `.min`: the dtype's minimum value; raises
// RuntimeError for dtypes the info object does not support.
static PyObject* PyDInfo_min(PyObject* self, void*) {
  HANDLE_ERRORS
  const DataType datatype = PyDTypeInfo_UnpackDataType(self);
  PyObject* py_min = PyGetMinVal(datatype);
  if (py_min != nullptr) { return py_min; }
  THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by "
                      << self->ob_type->tp_name;
  return nullptr;
  END_HANDLE_ERRORS
}
// Getter for iinfo/finfo `.max`: the dtype's maximum value; raises
// RuntimeError for dtypes the info object does not support.
static PyObject* PyDInfo_max(PyObject* self, void*) {
  HANDLE_ERRORS
  const DataType datatype = PyDTypeInfo_UnpackDataType(self);
  PyObject* py_max = PyGetMaxVal(datatype);
  if (py_max != nullptr) { return py_max; }
  THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << " not supported by "
                      << self->ob_type->tp_name;
  return nullptr;
  END_HANDLE_ERRORS
}
// Getter for finfo `.resolution` (10^-digits10 of the float type); raises
// RuntimeError for dtypes outside FLOAT_TYPE.
static PyObject* PyFInfo_resolution(PyObject* self, void*) {
  HANDLE_ERRORS
  DataType datatype = PyDTypeInfo_UnpackDataType(self);
  switch (datatype) {
    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_RESOLUTION, FLOAT_TYPE);
    default:
      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()
                          << " not supported by oneflow.finfo";
      // Unreachable after THROW; keeps the compiler's control-flow check happy.
      return NULL;
  }
  END_HANDLE_ERRORS
}
// Getter for finfo `.eps` (machine epsilon of the float type); raises
// RuntimeError for dtypes outside FLOAT_TYPE.
static PyObject* PyFInfo_eps(PyObject* self, void*) {
  HANDLE_ERRORS
  DataType datatype = PyDTypeInfo_UnpackDataType(self);
  switch (datatype) {
    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_EPS, FLOAT_TYPE);
    default:
      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()
                          << " not supported by oneflow.finfo";
      // Unreachable after THROW; keeps the compiler's control-flow check happy.
      return NULL;
  }
  END_HANDLE_ERRORS
}
// Getter for finfo `.tiny` (smallest positive normalized value,
// numeric_limits::min); raises RuntimeError for dtypes outside FLOAT_TYPE.
static PyObject* PyFInfo_tiny(PyObject* self, void*) {
  HANDLE_ERRORS
  DataType datatype = PyDTypeInfo_UnpackDataType(self);
  switch (datatype) {
    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_TINY, FLOAT_TYPE);
    default:
      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()
                          << " not supported by oneflow.finfo";
      // Unreachable after THROW; keeps the compiler's control-flow check happy.
      return NULL;
  }
  END_HANDLE_ERRORS
}
// Getter for iinfo/finfo `.dtype`: the dtype name with its leading
// "oneflow." prefix stripped, returned as a Python str.
static PyObject* PyDInfo_dtype(PyObject* self, void*) {
  HANDLE_ERRORS
  const std::string full_name = ((PyDTypeInfo*)self)->dtype->name();
  // Drop everything up to and including the first '.'; if there is no dot,
  // find() returns npos and npos + 1 == 0, so the name is kept unchanged —
  // the same behavior as erase(0, find('.') + 1).
  const size_t dot_pos = full_name.find('.');
  return PyUnicode_FromString(full_name.substr(dot_pos + 1).c_str());
  END_HANDLE_ERRORS
}
// __str__/__repr__ for oneflow.iinfo. The getters return NEW references;
// the original passed them straight into PyLong_AS_LONG and leaked one
// PyObject per field on every call — `take_long` now releases each reference
// after extracting the value.
static PyObject* PyIInfo_str(PyObject* self) {
  HANDLE_ERRORS
  const auto take_long = [](PyObject* obj) -> long {
    const long value = PyLong_AS_LONG(obj);
    Py_XDECREF(obj);
    return value;
  };
  std::ostringstream oss;
  oss << "iinfo(min=" << take_long(PyDInfo_min(self, NULL)) << ", ";
  oss << "max=" << take_long(PyDInfo_max(self, NULL)) << ", ";
  oss << "dtype=" << PyDTypeInfo_UnpackDType(self)->name() << ", ";
  oss << "bits=" << take_long(PyDInfo_bits(self, NULL)) << ")";
  return PyUnicode_FromString(oss.str().data());
  END_HANDLE_ERRORS
}
// __str__/__repr__ for oneflow.finfo. The getters return NEW references;
// the original passed them straight into PyFloat_AS_DOUBLE/PyLong_AS_LONG and
// leaked one PyObject per field on every call — the take_* helpers now release
// each reference after extracting the value.
static PyObject* PyFInfo_str(PyObject* self) {
  HANDLE_ERRORS
  const auto take_double = [](PyObject* obj) -> double {
    const double value = PyFloat_AS_DOUBLE(obj);
    Py_XDECREF(obj);
    return value;
  };
  const auto take_long = [](PyObject* obj) -> long {
    const long value = PyLong_AS_LONG(obj);
    Py_XDECREF(obj);
    return value;
  };
  std::ostringstream oss;
  oss << "finfo(resolution=" << take_double(PyFInfo_resolution(self, NULL)) << ", ";
  oss << "min=" << take_double(PyDInfo_min(self, NULL)) << ", ";
  oss << "max=" << take_double(PyDInfo_max(self, NULL)) << ", ";
  oss << "eps=" << take_double(PyFInfo_eps(self, NULL)) << ", ";
  oss << "tiny=" << take_double(PyFInfo_tiny(self, NULL)) << ", ";
  oss << "dtype=" << PyDTypeInfo_UnpackDType(self)->name() << ", ";
  oss << "bits=" << take_long(PyDInfo_bits(self, NULL)) << ")";
  return PyUnicode_FromString(oss.str().data());
  END_HANDLE_ERRORS
}
// Read-only attributes of oneflow.iinfo: bits, max, min, dtype.
static struct PyGetSetDef PyIInfo_properties[] = {
    {PYGETSET_NAME("bits"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("max"), (getter)PyDInfo_max, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("min"), (getter)PyDInfo_min, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("dtype"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr},
    {nullptr},  // sentinel
};
// Read-only attributes of oneflow.finfo: bits, max, min, resolution, eps,
// tiny, dtype.
static struct PyGetSetDef PyFInfo_properties[] = {
    {PYGETSET_NAME("bits"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("max"), (getter)PyDInfo_max, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("min"), (getter)PyDInfo_min, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("resolution"), (getter)PyFInfo_resolution, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("eps"), (getter)PyFInfo_eps, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("tiny"), (getter)PyFInfo_tiny, nullptr, nullptr, nullptr},
    {PYGETSET_NAME("dtype"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr},
    {nullptr},  // sentinel
};
static void init_info_type() {
PyIInfoType.tp_flags = Py_TPFLAGS_DEFAULT;
PyIInfoType.tp_str = (reprfunc)PyIInfo_str;
PyIInfoType.tp_repr = (reprfunc)PyIInfo_str;
PyIInfoType.tp_new = (newfunc)PyIInfo_new;
PyIInfoType.tp_getset = PyIInfo_properties;
if (PyType_Ready(&PyIInfoType) < 0) { return; }
PyFInfoType.tp_flags = Py_TPFLAGS_DEFAULT;
PyFInfoType.tp_str = (reprfunc)PyFInfo_str;
PyFInfoType.tp_repr = (reprfunc)PyFInfo_str;
PyFInfoType.tp_new = (newfunc)PyFInfo_new;
PyFInfoType.tp_getset = PyFInfo_properties;
if (PyType_Ready(&PyFInfoType) < 0) { return; }
}
// Registers oneflow._C.iinfo / oneflow._C.finfo.
// NOTE(review): PyModule_AddObject steals a reference on success; no Py_INCREF
// is done on these static type objects here — confirm that matches the
// intended module lifetime.
ONEFLOW_API_PYBIND11_MODULE("_C", m) {
  init_info_type();
  if (PyModule_AddObject(m.ptr(), "iinfo", (PyObject*)&PyIInfoType) < 0) return;
  if (PyModule_AddObject(m.ptr(), "finfo", (PyObject*)&PyFInfoType) < 0) return;
}
} // namespace one
} // namespace oneflow
#undef ASSERT
#undef GET_FLOAT_RESOLUTION
#undef GET_FLOAT_EPS
#undef GET_FLOAT_TINY
#undef INT_TYPE
#undef FLOAT_TYPE
#undef PYGETSET_NAME
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_
#include <Python.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/core/framework/dtype.h"
namespace oneflow {
namespace one {
// Object layout shared by oneflow.iinfo and oneflow.finfo instances: a plain
// PyObject header plus the dtype symbol being queried.
typedef struct {
  PyObject_HEAD;
  Symbol<DType> dtype;
} PyDTypeInfo;
extern PyTypeObject PyIInfoType;
extern PyTypeObject PyFInfoType;
// True iff `obj` is an oneflow.iinfo instance (or a subtype).
inline bool PyIInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyIInfoType); }
// True iff `obj` is an oneflow.finfo instance (or a subtype).
inline bool PyFInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyFInfoType); }
// True iff `obj` is either kind of dtype-info object.
inline bool PyDTypeInfo_Check(PyObject* obj) { return PyIInfo_Check(obj) || PyFInfo_Check(obj); }
// Returns the dtype symbol held by an iinfo/finfo object.
// Precondition (asserted in debug builds): PyDTypeInfo_Check(obj).
inline Symbol<DType> PyDTypeInfo_UnpackDType(PyObject* obj) {
  assert(PyDTypeInfo_Check(obj));
  return ((PyDTypeInfo*)obj)->dtype;
}
// Returns the raw DataType enum held by an iinfo/finfo object.
// Precondition (asserted in debug builds): PyDTypeInfo_Check(obj).
inline DataType PyDTypeInfo_UnpackDataType(PyObject* obj) {
  assert(PyDTypeInfo_Check(obj));
  return ((PyDTypeInfo*)obj)->dtype->data_type();
}
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_
...@@ -26,7 +26,7 @@ namespace oneflow { ...@@ -26,7 +26,7 @@ namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) { ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("FillVariableTensorMgr", &FillVariableTensorMgr); m.def("FillVariableTensorMgr", &FillVariableTensorMgr);
m.def("DumpVariableTensorMgr", &DumpVariableTensorMgr); m.def("DumpVariableTensorMgr", &DumpVariableTensorMgr);
m.def("ClearVariableTensorMgr", &ClearVariableTensorMgr); m.def("ResetVariableTensorMgr", &ResetVariableTensorMgr);
} }
} // namespace oneflow } // namespace oneflow
...@@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/common.h"
#include <object.h> #include <object.h>
#include <string> #include <string>
...@@ -28,12 +27,62 @@ limitations under the License. ...@@ -28,12 +27,62 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/random_generator.h" #include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/functional/tensor_index.h" #include "oneflow/core/functional/tensor_index.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/framework/tensor_util.h"
namespace oneflow { namespace oneflow {
namespace one { namespace one {
namespace functional { namespace functional {
namespace detail {
namespace {
// Unpacks the Python tensor object and delegates to GetItemInScalarTensor to
// read its single element as T.
template<typename T>
Maybe<T> GetItemInPyScalarTensor(PyObject* obj) {
  return GetItemInScalarTensor<T>(PyTensor_Unpack(obj));
}
} // namespace
// Fast isinstance check against the pybind11-registered type of T. The type
// handle lookup is done once and cached in a function-local static; returns
// false if T was never registered with pybind11. Throws if
// PyObject_IsInstance itself fails.
template<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value, int>::type = 0>
bool isinstance_fast(PyObject* obj) {
  static auto type = py::detail::get_type_handle(typeid(T), false);
  if (!type) { return false; }
  const auto result = PyObject_IsInstance(obj, type.ptr());
  if (result == -1) { throw py::error_already_set(); }
  return result != 0;
}
// Fast cast that reads the value pointer directly out of the pybind11
// instance, bypassing the full pybind11 caster. Callers must have verified
// `obj` (e.g. via isinstance_fast) first.
// NOTE(review): `*reinterpret_cast<T*>(&vptr)` reinterprets the storage of the
// value pointer itself as a T — valid only when T is itself a pointer type
// (it is used with T = Symbol<DType>* in this file); confirm no other T is
// ever passed here.
template<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value
                                                 && !py::detail::is_shared_ptr<T>::value,
                                             int>::type = 0>
const T& cast_fast(PyObject* obj) {
  auto vh = reinterpret_cast<py::detail::instance*>(obj)->get_value_and_holder();
  auto*& vptr = vh.value_ptr();
  if (!vptr) {
    throw py::cast_error("Unable to cast from object to T& since lazy allocation is not allowed "
                         "for fast cast, please use pybind11::cast instead");
  }
  return *reinterpret_cast<T*>(&vptr);
}
// Fast cast for shared_ptr-held pybind11 instances: returns a reference to
// the shared_ptr holder stored inside the instance. Throws when the holder
// was never constructed (a non-held instance).
template<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value
                                                 && py::detail::is_shared_ptr<T>::value,
                                             int>::type = 0>
const T& cast_fast(PyObject* obj) {
  auto vh = reinterpret_cast<py::detail::instance*>(obj)->get_value_and_holder();
  if (!vh.holder_constructed()) {
    throw py::cast_error("Unable to cast from non-held to held instance (T& to Holder<T>)");
  }
  return vh.template holder<T>();
}
} // namespace detail
bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check) { bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check) {
bool is_tuple = PyTuple_Check(obj); bool is_tuple = PyTuple_Check(obj);
if (!is_tuple && !PyList_Check(obj)) { return false; } if (!is_tuple && !PyList_Check(obj)) { return false; }
...@@ -44,12 +93,15 @@ bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_c ...@@ -44,12 +93,15 @@ bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_c
} }
bool PyLongSequenceCheck(PyObject* obj) { bool PyLongSequenceCheck(PyObject* obj) {
return PySequenceCheck(obj, [](PyObject* item) { return PyLong_Check(item); }); return PySequenceCheck(
obj, [](PyObject* item) { return PyLong_Check(item) || PyIntegerScalarTensorCheck(item); });
} }
bool PyFloatSquenceCheck(PyObject* obj) { bool PyFloatSequenceCheck(PyObject* obj) {
return PySequenceCheck(obj, return PySequenceCheck(obj, [](PyObject* item) {
[](PyObject* item) { return PyFloat_Check(item) || PyLong_Check(item); }); return PyFloat_Check(item) || PyLong_Check(item) || PyFloatScalarTensorCheck(item)
|| PyIntegerScalarTensorCheck(item);
});
} }
bool PyStringCheck(PyObject* obj) { return PyBytes_Check(obj) || PyUnicode_Check(obj); } bool PyStringCheck(PyObject* obj) { return PyBytes_Check(obj) || PyUnicode_Check(obj); }
...@@ -82,14 +134,10 @@ std::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj) { ...@@ -82,14 +134,10 @@ std::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj) {
} }
// TensorTuple // TensorTuple
bool PyTensorTupleCheck(PyObject* obj) { bool PyTensorTupleCheck(PyObject* obj) { return detail::isinstance_fast<TensorTuple>(obj); }
auto handle = py::reinterpret_borrow<py::object>(obj);
return py::isinstance<TensorTuple>(handle);
}
std::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj) { std::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return detail::cast_fast<std::shared_ptr<TensorTuple>>(obj);
return py::cast<std::shared_ptr<TensorTuple>>(handle);
} }
// Scalar // Scalar
...@@ -107,16 +155,60 @@ Scalar PyUnpackScalar(PyObject* obj) { ...@@ -107,16 +155,60 @@ Scalar PyUnpackScalar(PyObject* obj) {
return 0; return 0;
} }
// DType // Scalar Tensor
bool PyDTypeCheck(PyObject* obj) { bool PyScalarTensorCheck(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); if (!LazyMode::is_enabled() && PyTensor_Check(obj)) {
return py::isinstance<Symbol<DType>>(handle); const auto& tensor = PyTensor_Unpack(obj);
return tensor->shape()->size() == 0 && IsPODDataType(tensor->dtype()->data_type());
}
return false;
} }
Symbol<DType> PyUnpackDType(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); Scalar PyUnpackScalarTensor(PyObject* obj) {
return *py::cast<Symbol<DType>*>(handle); if (PyBoolScalarTensorCheck(obj)) {
return PyUnpackBoolScalarTensor(obj);
} else if (PyIntegerScalarTensorCheck(obj)) {
return PyUnpackIntegerScalarTensor_AsLongLong(obj);
} else if (PyFloatScalarTensorCheck(obj)) {
return PyUnpackFloatScalarTensor_AsDouble(obj);
}
THROW(RuntimeError) << "The object is not scalar tensor, but is " << Py_TYPE(obj)->tp_name
<< "with data type: "
<< DataType_Name(PyTensor_Unpack(obj)->dtype()->data_type());
return 0;
} }
// Expands to one `case` per (cpp_type, of_datatype) pair: reads the scalar
// tensor's single element as `cpp_type` and returns it.
#define SWITCH_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \
  case of_type: \
    return detail::GetItemInPyScalarTensor<cpp_type>(obj).GetOrThrow();

// Defines `return_type func_name(PyObject* obj)` extracting the value of a
// 0-dim tensor whose data type appears in `type_seq`; raises py::cast_error
// otherwise. `#return_type` stringizes the target C++ type for the error
// message — the original "##cpp##type" was a literal string, since the `##`
// paste operator has no effect inside a string literal.
#define SCALAR_TENSOR_UNPACK_FUNC_IMPL(func_name, return_type, type_seq)             \
  return_type func_name(PyObject* obj) {                                             \
    const auto& tensor = PyTensor_Unpack(obj);                                       \
    DataType data_type = tensor->dtype()->data_type();                               \
    switch (data_type) {                                                             \
      OF_PP_FOR_EACH_TUPLE(SWITCH_SCALAR_TENSOR_TO_SCALAR, type_seq)                 \
      default: {                                                                     \
        throw py::cast_error("Cannot get " #return_type                              \
                             " from scalar tensor with data type: "                  \
                             + DataType_Name(data_type));                            \
      }                                                                              \
    }                                                                                \
  }

SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackBoolScalarTensor, bool,
                               BOOL_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ);
SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackIntegerScalarTensor_AsLongLong, long long,
                               INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ
                                   CHAR_DATA_TYPE_SEQ);
SCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackFloatScalarTensor_AsDouble, double,
                               FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ);

#undef SWITCH_SCALAR_TENSOR_TO_SCALAR
#undef SCALAR_TENSOR_UNPACK_FUNC_IMPL
// DType
bool PyDTypeCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<DType>>(obj); }
Symbol<DType> PyUnpackDType(PyObject* obj) { return *detail::cast_fast<Symbol<DType>*>(obj); }
// DType list // DType list
bool PyDTypeSequenceCheck(PyObject* obj) { bool PyDTypeSequenceCheck(PyObject* obj) {
return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); }); return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); });
...@@ -125,55 +217,54 @@ std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj) { ...@@ -125,55 +217,54 @@ std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj) {
return PyUnpackSequence<Symbol<DType>>(obj, [](PyObject* item) { return PyUnpackDType(item); }); return PyUnpackSequence<Symbol<DType>>(obj, [](PyObject* item) { return PyUnpackDType(item); });
} }
// Shape
bool PyShapeCheck(PyObject* obj) { return PyLongSequenceCheck(obj); }
Shape PyUnpackShape(PyObject* obj) {
bool is_tuple = PyTuple_Check(obj);
CHECK_OR_THROW(is_tuple || PyList_Check(obj))
<< "The object is not list or tuple, but is " << Py_TYPE(obj)->tp_name;
size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
DimVector values(size);
for (int i = 0; i < size; ++i) {
PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i);
values[i] = PyLong_AsLongLong(item);
}
return Shape(values);
}
// Shape list // Shape list
bool PyShapeSequenceCheck(PyObject* obj) { bool PyShapeSequenceCheck(PyObject* obj) {
return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); }); return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); });
} }
std::vector<Shape> PyUnpackShapeSequence(PyObject* obj) { std::vector<Shape> PyUnpackShapeSequence(PyObject* obj) {
return PyUnpackSequence<Shape>(obj, [](PyObject* item) -> Shape { return PyUnpackSequence<Shape>(obj, [](PyObject* item) -> Shape { return PyUnpackShape(item); });
const auto& shape = PyUnpackLongSequence<int64_t>(item);
return Shape(DimVector(shape.begin(), shape.end()));
});
} }
// Generator // Generator
bool PyGeneratorCheck(PyObject* obj) { bool PyGeneratorCheck(PyObject* obj) { return detail::isinstance_fast<Generator>(obj); }
auto handle = py::reinterpret_borrow<py::object>(obj);
return py::isinstance<Generator>(handle);
}
std::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj) { std::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return detail::cast_fast<std::shared_ptr<one::Generator>>(obj);
return py::cast<std::shared_ptr<one::Generator>>(handle);
} }
// Device // Device
bool PyDeviceCheck(PyObject* obj) { bool PyDeviceCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<Device>>(obj); }
auto handle = py::reinterpret_borrow<py::object>(obj);
return py::isinstance<Symbol<Device>>(handle);
}
Symbol<Device> PyUnpackDevice(PyObject* obj) { Symbol<Device> PyUnpackDevice(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return *detail::cast_fast<std::shared_ptr<Symbol<Device>>>(obj);
return *py::cast<std::shared_ptr<Symbol<Device>>>(handle);
} }
// Placement // Placement
bool PyParallelDescCheck(PyObject* obj) { bool PyParallelDescCheck(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return detail::isinstance_fast<Symbol<ParallelDesc>>(obj);
return py::isinstance<Symbol<ParallelDesc>>(handle);
} }
Symbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj) { Symbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return *detail::cast_fast<std::shared_ptr<Symbol<ParallelDesc>>>(obj);
return *py::cast<std::shared_ptr<Symbol<ParallelDesc>>>(handle);
} }
// SBP // SBP
bool PySbpParallelCheck(PyObject* obj) { bool PySbpParallelCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<SbpParallel>>(obj); }
auto handle = py::reinterpret_borrow<py::object>(obj);
return py::isinstance<Symbol<SbpParallel>>(handle);
}
Symbol<SbpParallel> PyUnpackSbpParallel(PyObject* obj) { Symbol<SbpParallel> PyUnpackSbpParallel(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return *detail::cast_fast<std::shared_ptr<Symbol<SbpParallel>>>(obj);
return *py::cast<std::shared_ptr<Symbol<SbpParallel>>>(handle);
} }
// SBP list // SBP list
...@@ -280,14 +371,10 @@ TensorIndex PyUnpackTensorIndex(PyObject* obj) { ...@@ -280,14 +371,10 @@ TensorIndex PyUnpackTensorIndex(PyObject* obj) {
} }
// OpExpr // OpExpr
bool PyOpExprCheck(PyObject* obj) { bool PyOpExprCheck(PyObject* obj) { return detail::isinstance_fast<OpExpr>(obj); }
auto handle = py::reinterpret_borrow<py::object>(obj);
return py::isinstance<OpExpr>(handle);
}
std::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj) { std::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj); return detail::cast_fast<std::shared_ptr<OpExpr>>(obj);
return py::cast<std::shared_ptr<OpExpr>>(handle);
} }
// int64_t // int64_t
......
...@@ -21,6 +21,8 @@ limitations under the License. ...@@ -21,6 +21,8 @@ limitations under the License.
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "oneflow/api/python/framework/tensor.h" #include "oneflow/api/python/framework/tensor.h"
#include "oneflow/api/python/caster/maybe.h"
#include "oneflow/api/python/caster/optional.h"
#include "oneflow/core/common/throw.h" #include "oneflow/core/common/throw.h"
#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
...@@ -52,11 +54,11 @@ struct PyObjectPtrDeleter { ...@@ -52,11 +54,11 @@ struct PyObjectPtrDeleter {
using PyObjectPtr = std::unique_ptr<PyObject, PyObjectPtrDeleter>; using PyObjectPtr = std::unique_ptr<PyObject, PyObjectPtrDeleter>;
#define INTEGER_TYPE_SEQ \ #define INTEGER_AND_BOOL_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int32_t) \ OF_PP_MAKE_TUPLE_SEQ(int32_t) \
OF_PP_MAKE_TUPLE_SEQ(uint32_t) \ OF_PP_MAKE_TUPLE_SEQ(uint32_t) \
OF_PP_MAKE_TUPLE_SEQ(int64_t) \ OF_PP_MAKE_TUPLE_SEQ(int64_t) \
OF_PP_MAKE_TUPLE_SEQ(uint64_t) \ OF_PP_MAKE_TUPLE_SEQ(uint64_t) \
OF_PP_MAKE_TUPLE_SEQ(bool) OF_PP_MAKE_TUPLE_SEQ(bool)
#define FLOATING_TYPE_SEQ \ #define FLOATING_TYPE_SEQ \
...@@ -80,20 +82,47 @@ inline std::vector<T> PyUnpackSequence(PyObject* obj, UnpackItemFunc unpack_item ...@@ -80,20 +82,47 @@ inline std::vector<T> PyUnpackSequence(PyObject* obj, UnpackItemFunc unpack_item
return values; return values;
} }
// Scalar Tensor
bool PyScalarTensorCheck(PyObject* obj);
Scalar PyUnpackScalarTensor(PyObject* obj);
#define DefinePyTypeScalarTensorCheck(type, type_check_func) \
inline bool Py##type##ScalarTensorCheck(PyObject* obj) { \
return PyScalarTensorCheck(obj) \
&& type_check_func(PyTensor_Unpack(obj)->dtype()->data_type()); \
}
DefinePyTypeScalarTensorCheck(Bool, IsBoolDataType); // PyBoolScalarTensorCheck
DefinePyTypeScalarTensorCheck(Integer, IsIntegralDataType); // PyIntegerScalarTensorCheck
DefinePyTypeScalarTensorCheck(Float, IsFloatingDataType); // PyFloatScalarTensorCheck
#undef DefinePyTypeScalarTensorCheck
bool PyUnpackBoolScalarTensor(PyObject* obj);
long long PyUnpackIntegerScalarTensor_AsLongLong(PyObject* obj);
double PyUnpackFloatScalarTensor_AsDouble(PyObject* obj);
// Integer/Float list // Integer/Float list
bool PyLongSequenceCheck(PyObject* obj); bool PyLongSequenceCheck(PyObject* obj);
bool PyFloatSquenceCheck(PyObject* obj); bool PyFloatSequenceCheck(PyObject* obj);
template<typename T> template<typename T>
inline std::vector<T> PyUnpackLongSequence(PyObject* obj) { inline std::vector<T> PyUnpackLongSequence(PyObject* obj) {
return PyUnpackSequence<T>( return PyUnpackSequence<T>(obj, [](PyObject* item) -> T {
obj, [](PyObject* item) -> T { return static_cast<T>(PyLong_AsLongLong(item)); }); if (PyIntegerScalarTensorCheck(item)) {
return static_cast<T>(PyUnpackIntegerScalarTensor_AsLongLong(item));
}
return static_cast<T>(PyLong_AsLongLong(item));
});
} }
template<typename T> template<typename T>
inline std::vector<T> PyUnpackFloatSequence(PyObject* obj) { inline std::vector<T> PyUnpackFloatSequence(PyObject* obj) {
return PyUnpackSequence<T>( return PyUnpackSequence<T>(obj, [](PyObject* item) -> T {
obj, [](PyObject* item) -> T { return static_cast<T>(PyFloat_AsDouble(item)); }); if (PyFloatScalarTensorCheck(item)) {
return static_cast<T>(PyUnpackFloatScalarTensor_AsDouble(item));
}
return static_cast<T>(PyFloat_AsDouble(item));
});
} }
// String // String
...@@ -124,6 +153,10 @@ Symbol<DType> PyUnpackDType(PyObject* obj); ...@@ -124,6 +153,10 @@ Symbol<DType> PyUnpackDType(PyObject* obj);
bool PyDTypeSequenceCheck(PyObject* obj); bool PyDTypeSequenceCheck(PyObject* obj);
std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj); std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj);
// Shape
bool PyShapeCheck(PyObject* obj);
Shape PyUnpackShape(PyObject* obj);
// Shape list // Shape list
bool PyShapeSequenceCheck(PyObject* obj); bool PyShapeSequenceCheck(PyObject* obj);
std::vector<Shape> PyUnpackShapeSequence(PyObject* obj); std::vector<Shape> PyUnpackShapeSequence(PyObject* obj);
......
...@@ -16,7 +16,9 @@ limitations under the License. ...@@ -16,7 +16,9 @@ limitations under the License.
#include "oneflow/core/common/scalar.h" #include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/op_interpreter/lazy_op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_tuple.h"
...@@ -33,19 +35,26 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -33,19 +35,26 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor( m.add_functor(
"DispatchFeedInput", "DispatchFeedInput",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> { [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {
return OpInterpUtil::Dispatch<Tensor>(*op, {input}); const auto& origin_input = JUST(OpInterpUtil::Dispatch<Tensor>(*op, {input}));
// Unpack input when do grad acc
return GradAccTryInsertUnpackAfterInput(origin_input);
}); });
m.add_functor( m.add_functor(
"DispatchFetchOutput", "DispatchFetchOutput",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> { [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {
return OpInterpUtil::Dispatch<Tensor>(*op, {input}); // Pack output when do grad acc
const auto& pack_input = JUST(GradAccTryInsertPackBeforeOutput(input));
return OpInterpUtil::Dispatch<Tensor>(*op, {pack_input});
}); });
m.add_functor("DispatchFeedVariable", m.add_functor("DispatchFeedVariable",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const Scalar& l2) -> Maybe<Tensor> { const Scalar& l2) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("l2");
JUST(attrs.SetAttr<double>("l2", l2.As<double>())); attrs.SetAllAttrs(l2.As<double>());
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); const auto& origin_var =
JUST(OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs));
// Repeat variable when do grad acc
return GradAccTryInsertRepeatAfterVar(origin_var);
}); });
m.add_functor( m.add_functor(
"DispatchOfrecordReader", "DispatchOfrecordReader",
...@@ -53,16 +62,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -53,16 +62,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size, const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size,
int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,
const Optional<Symbol<Device>>& device) -> Maybe<Tensor> { const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("data_dir", data_dir)); "data_dir", "data_part_num", "part_name_prefix", "part_name_suffix_length",
JUST(attrs.SetAttr("data_part_num", data_part_num)); "batch_size", "shuffle_buffer_size", "random_shuffle", "shuffle_after_epoch", "seed");
JUST(attrs.SetAttr("part_name_prefix", part_name_prefix)); attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length,
JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length)); batch_size, shuffle_buffer_size, random_shuffle, shuffle_after_epoch,
JUST(attrs.SetAttr("batch_size", batch_size)); seed);
JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size));
JUST(attrs.SetAttr("random_shuffle", random_shuffle));
JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
JUST(attrs.SetAttr("seed", seed));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device))); return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
}); });
m.add_functor( m.add_functor(
...@@ -72,17 +77,13 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -72,17 +77,13 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,
const Symbol<ParallelDesc>& placement, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> { const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("data_dir", data_dir)); "data_dir", "data_part_num", "part_name_prefix", "part_name_suffix_length",
JUST(attrs.SetAttr("data_part_num", data_part_num)); "batch_size", "shuffle_buffer_size", "random_shuffle", "shuffle_after_epoch", "seed",
JUST(attrs.SetAttr("part_name_prefix", part_name_prefix)); "nd_sbp");
JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length)); attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length,
JUST(attrs.SetAttr("batch_size", batch_size)); batch_size, shuffle_buffer_size, random_shuffle, shuffle_after_epoch,
JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size)); seed, *JUST(GetNdSbpStrList(sbp_tuple)));
JUST(attrs.SetAttr("random_shuffle", random_shuffle));
JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
JUST(attrs.SetAttr("seed", seed));
JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, return OpInterpUtil::Dispatch<Tensor>(*op, {},
OpExprInterpContext(attrs, placement, nd_sbp)); OpExprInterpContext(attrs, placement, nd_sbp));
...@@ -91,35 +92,29 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -91,35 +92,29 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& name, const Shape& shape, const Symbol<DType>& data_type, const std::string& name, const Shape& shape, const Symbol<DType>& data_type,
bool dim1_varying_length, bool truncate) -> Maybe<Tensor> { bool dim1_varying_length, bool truncate) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name", "shape", "data_type",
JUST(attrs.SetAttr("name", name)); "dim1_varying_length", "truncate");
JUST(attrs.SetAttr("shape", shape)); attrs.SetAllAttrs(name, shape, data_type->data_type(), dim1_varying_length,
JUST(attrs.SetAttr("data_type", data_type->data_type())); truncate);
JUST(attrs.SetAttr("dim1_varying_length", dim1_varying_length));
JUST(attrs.SetAttr("truncate", truncate));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor(
"DispatchCoinFlip", "DispatchCoinFlip",
[](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability, int64_t seed, [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability, int64_t seed,
bool has_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> { bool has_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("probability", probability.As<float>())); THREAD_CACHED_MUTABLE_ATTR_MAP("probability", "batch_size", "seed", "has_seed");
JUST(attrs.SetAttr("batch_size", batch_size)); attrs.SetAllAttrs(probability.As<float>(), batch_size, seed, has_seed);
JUST(attrs.SetAttr("seed", seed));
JUST(attrs.SetAttr("has_seed", has_seed));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device))); return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
}); });
m.add_functor("DispatchCoinFlip", m.add_functor("DispatchCoinFlip",
[](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability, [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability,
int64_t seed, bool has_seed, const Symbol<ParallelDesc>& placement, int64_t seed, bool has_seed, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> { const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("probability", "batch_size", "seed",
JUST(attrs.SetAttr("probability", probability.As<float>())); "has_seed", "nd_sbp");
JUST(attrs.SetAttr("batch_size", batch_size)); attrs.SetAllAttrs(probability.As<float>(), batch_size, seed, has_seed,
JUST(attrs.SetAttr("seed", seed)); *JUST(GetNdSbpStrList(sbp_tuple)));
JUST(attrs.SetAttr("has_seed", has_seed));
JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>( return OpInterpUtil::Dispatch<Tensor>(
*op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
...@@ -128,8 +123,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -128,8 +123,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
"DispatchDistributedPariticalFCSample", "DispatchDistributedPariticalFCSample",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& weight, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& weight,
const std::shared_ptr<Tensor>& label, const int64_t& num_sample) -> Maybe<TensorTuple> { const std::shared_ptr<Tensor>& label, const int64_t& num_sample) -> Maybe<TensorTuple> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("num_sample");
JUST(attrs.SetAttr<int64_t>("num_sample", num_sample)); attrs.SetAllAttrs(num_sample);
return OpInterpUtil::Dispatch<TensorTuple>(*op, {weight, label}, attrs); return OpInterpUtil::Dispatch<TensorTuple>(*op, {weight, label}, attrs);
}); });
m.add_functor( m.add_functor(
...@@ -138,16 +133,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -138,16 +133,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean, int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,
const std::vector<float>& std, const Symbol<DType>& output_dtype, const std::vector<float>& std, const Symbol<DType>& output_dtype,
const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> { const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("color_space", color_space)); THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "output_layout", "mean", "std", "crop_h",
JUST(attrs.SetAttr("output_layout", output_layout)); "crop_w", "crop_pos_x", "crop_pos_y", "output_dtype");
JUST(attrs.SetAttr("mean", mean)); attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x,
JUST(attrs.SetAttr("std", std)); crop_pos_y, output_dtype->data_type());
JUST(attrs.SetAttr("crop_h", crop_h));
JUST(attrs.SetAttr("crop_w", crop_w));
JUST(attrs.SetAttr("crop_pos_x", crop_pos_x));
JUST(attrs.SetAttr("crop_pos_y", crop_pos_y));
JUST(attrs.SetAttr("output_dtype", output_dtype->data_type()));
return OpInterpUtil::Dispatch<Tensor>(*op, input, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, input, attrs);
}); });
m.add_functor( m.add_functor(
...@@ -156,16 +146,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -156,16 +146,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean, int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,
const std::vector<float>& std, const Symbol<DType>& output_dtype, const std::vector<float>& std, const Symbol<DType>& output_dtype,
const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> { const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("color_space", color_space)); THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "output_layout", "mean", "std", "crop_h",
JUST(attrs.SetAttr("output_layout", output_layout)); "crop_w", "crop_pos_x", "crop_pos_y", "output_dtype");
JUST(attrs.SetAttr("mean", mean)); attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x,
JUST(attrs.SetAttr("std", std)); crop_pos_y, output_dtype->data_type());
JUST(attrs.SetAttr("crop_h", crop_h));
JUST(attrs.SetAttr("crop_w", crop_w));
JUST(attrs.SetAttr("crop_pos_x", crop_pos_x));
JUST(attrs.SetAttr("crop_pos_y", crop_pos_y));
JUST(attrs.SetAttr("output_dtype", output_dtype->data_type()));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor(
...@@ -174,22 +159,18 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -174,22 +159,18 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
const std::string& name, const std::string& color_space, const std::string& name, const std::string& color_space,
const std::vector<float>& random_area, const std::vector<float>& random_aspect_ratio, const std::vector<float>& random_area, const std::vector<float>& random_aspect_ratio,
int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe<Tensor> { int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("name", name)); THREAD_CACHED_MUTABLE_ATTR_MAP("name", "color_space", "num_attempts", "seed",
JUST(attrs.SetAttr("color_space", color_space)); "has_seed", "random_area", "random_aspect_ratio");
JUST(attrs.SetAttr("num_attempts", num_attempts)); attrs.SetAllAttrs(name, color_space, num_attempts, seed, has_seed, random_area,
JUST(attrs.SetAttr("seed", seed)); random_aspect_ratio);
JUST(attrs.SetAttr("has_seed", has_seed));
JUST(attrs.SetAttr("random_area", random_area));
JUST(attrs.SetAttr("random_aspect_ratio", random_aspect_ratio));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor("DispatchOfrecordImageDecoder", m.add_functor("DispatchOfrecordImageDecoder",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& name, const std::string& color_space) -> Maybe<Tensor> { const std::string& name, const std::string& color_space) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name", "color_space");
JUST(attrs.SetAttr("name", name)); attrs.SetAllAttrs(name, color_space);
JUST(attrs.SetAttr("color_space", color_space));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor("DispatchImageDecoderRandomCropResize", m.add_functor("DispatchImageDecoderRandomCropResize",
...@@ -198,18 +179,13 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -198,18 +179,13 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int64_t max_num_pixels, float random_area_min, float random_area_max, int64_t max_num_pixels, float random_area_min, float random_area_max,
float random_aspect_ratio_min, float random_aspect_ratio_max, float random_aspect_ratio_min, float random_aspect_ratio_max,
int64_t warmup_size, int64_t num_attempts) -> Maybe<Tensor> { int64_t warmup_size, int64_t num_attempts) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("target_width", target_width)); "target_width", "target_height", "seed", "num_workers", "max_num_pixels",
JUST(attrs.SetAttr("target_height", target_height)); "random_area_min", "random_area_max", "random_aspect_ratio_min",
JUST(attrs.SetAttr("seed", seed)); "random_aspect_ratio_max", "warmup_size", "num_attempts");
JUST(attrs.SetAttr("num_workers", num_workers)); attrs.SetAllAttrs(target_width, target_height, seed, num_workers, max_num_pixels,
JUST(attrs.SetAttr("max_num_pixels", max_num_pixels)); random_area_min, random_area_max, random_aspect_ratio_min,
JUST(attrs.SetAttr("random_area_min", random_area_min)); random_aspect_ratio_max, warmup_size, num_attempts);
JUST(attrs.SetAttr("random_area_max", random_area_max));
JUST(attrs.SetAttr("random_aspect_ratio_min", random_aspect_ratio_min));
JUST(attrs.SetAttr("random_aspect_ratio_max", random_aspect_ratio_max));
JUST(attrs.SetAttr("warmup_size", warmup_size));
JUST(attrs.SetAttr("num_attempts", num_attempts));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor(
...@@ -217,26 +193,22 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -217,26 +193,22 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::vector<Shape>& out_shapes, const std::vector<Symbol<DType>>& out_dtypes, const std::vector<Shape>& out_shapes, const std::vector<Symbol<DType>>& out_dtypes,
bool dynamic_out) -> Maybe<TensorTuple> { bool dynamic_out) -> Maybe<TensorTuple> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("out_shapes", out_shapes));
JUST(attrs.SetAttr("dynamic_out", dynamic_out));
auto out_data_types = std::vector<DataType>(); auto out_data_types = std::vector<DataType>();
for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) { for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) {
out_data_types.emplace_back((*it)->data_type()); out_data_types.emplace_back((*it)->data_type());
} }
JUST(attrs.SetAttr("out_dtypes", out_data_types)); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("out_shapes", "dynamic_out", "out_dtypes");
attrs.SetAllAttrs(out_shapes, dynamic_out, out_data_types);
return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs); return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
}); });
m.add_functor("DispatchImageResizeKeepAspectRatio", m.add_functor("DispatchImageResizeKeepAspectRatio",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer, int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer,
const std::string& interpolation_type) -> Maybe<TensorTuple> { const std::string& interpolation_type) -> Maybe<TensorTuple> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("target_size", target_size)); "target_size", "min_size", "max_size", "resize_longer", "interpolation_type");
JUST(attrs.SetAttr("min_size", min_size)); attrs.SetAllAttrs(target_size, min_size, max_size, resize_longer,
JUST(attrs.SetAttr("max_size", max_size)); interpolation_type);
JUST(attrs.SetAttr("resize_longer", resize_longer));
JUST(attrs.SetAttr("interpolation_type", interpolation_type));
return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs); return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
}); });
m.add_functor("DispatchImageResizeToFixed", m.add_functor("DispatchImageResizeToFixed",
...@@ -244,89 +216,76 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -244,89 +216,76 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int64_t target_width, int64_t target_height, int64_t channels, int64_t target_width, int64_t target_height, int64_t channels,
const Symbol<DType>& data_type, const Symbol<DType>& data_type,
const std::string& interpolation_type) -> Maybe<TensorTuple> { const std::string& interpolation_type) -> Maybe<TensorTuple> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("target_width", target_width)); THREAD_CACHED_MUTABLE_ATTR_MAP("target_width", "target_height", "channels",
JUST(attrs.SetAttr("target_height", target_height)); "data_type", "interpolation_type");
JUST(attrs.SetAttr("channels", channels)); attrs.SetAllAttrs(target_width, target_height, channels, data_type->data_type(),
JUST(attrs.SetAttr("data_type", data_type->data_type())); interpolation_type);
JUST(attrs.SetAttr("interpolation_type", interpolation_type));
return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs); return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor(
"DispatchImageDecode", "DispatchImageDecode",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& color_space, const Symbol<DType>& data_type) -> Maybe<Tensor> { const std::string& color_space, const Symbol<DType>& data_type) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("color_space", "data_type");
JUST(attrs.SetAttr("color_space", color_space)); attrs.SetAllAttrs(color_space, data_type->data_type());
JUST(attrs.SetAttr("data_type", data_type->data_type()));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor("DispatchImageNormalize", m.add_functor("DispatchImageNormalize",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::vector<float>& mean, const std::vector<float>& std) -> Maybe<Tensor> { const std::vector<float>& mean, const std::vector<float>& std) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("std", "mean");
JUST(attrs.SetAttr("std", std)); attrs.SetAllAttrs(std, mean);
JUST(attrs.SetAttr("mean", mean));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor("DispatchCOCOReader",
"DispatchCOCOReader", [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,
[](const std::shared_ptr<OpExpr>& op, const std::string& image_dir, const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,
const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,
int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, bool stride_partition, int64_t session_id,
bool stride_partition, int64_t session_id, const Optional<Symbol<Device>>& device) -> Maybe<TensorTuple> {
const Optional<Symbol<Device>>& device) -> Maybe<TensorTuple> { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
MutableAttrMap attrs; "session_id", "annotation_file", "image_dir", "batch_size",
JUST(attrs.SetAttr("session_id", session_id)); "shuffle_after_epoch", "random_seed", "group_by_ratio",
JUST(attrs.SetAttr("annotation_file", annotation_file)); "remove_images_without_annotations", "stride_partition");
JUST(attrs.SetAttr("image_dir", image_dir)); attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size,
JUST(attrs.SetAttr("batch_size", batch_size)); shuffle_after_epoch, random_seed, group_by_ratio,
JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); remove_images_without_annotations, stride_partition);
JUST(attrs.SetAttr("random_seed", random_seed)); return OpInterpUtil::Dispatch<TensorTuple>(
JUST(attrs.SetAttr("group_by_ratio", group_by_ratio)); *op, {}, OpExprInterpContext(attrs, JUST(device)));
JUST(attrs.SetAttr("remove_images_without_annotations", remove_images_without_annotations)); });
JUST(attrs.SetAttr("stride_partition", stride_partition)); m.add_functor("DispatchCOCOReader",
return OpInterpUtil::Dispatch<TensorTuple>(*op, {}, [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,
OpExprInterpContext(attrs, JUST(device))); const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,
}); int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,
m.add_functor( bool stride_partition, int64_t session_id, const Symbol<ParallelDesc>& placement,
"DispatchCOCOReader", const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<TensorTuple> {
[](const std::shared_ptr<OpExpr>& op, const std::string& image_dir, auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, "session_id", "annotation_file", "image_dir", "batch_size",
int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, "shuffle_after_epoch", "random_seed", "group_by_ratio",
bool stride_partition, int64_t session_id, const Symbol<ParallelDesc>& placement, "remove_images_without_annotations", "stride_partition", "nd_sbp");
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<TensorTuple> { attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size,
MutableAttrMap attrs; shuffle_after_epoch, random_seed, group_by_ratio,
JUST(attrs.SetAttr("session_id", session_id)); remove_images_without_annotations, stride_partition,
JUST(attrs.SetAttr("annotation_file", annotation_file)); *JUST(GetNdSbpStrList(sbp_tuple)));
JUST(attrs.SetAttr("image_dir", image_dir)); auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
JUST(attrs.SetAttr("batch_size", batch_size)); return OpInterpUtil::Dispatch<TensorTuple>(
JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
JUST(attrs.SetAttr("random_seed", random_seed)); });
JUST(attrs.SetAttr("group_by_ratio", group_by_ratio));
JUST(attrs.SetAttr("remove_images_without_annotations", remove_images_without_annotations));
JUST(attrs.SetAttr("stride_partition", stride_partition));
JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<TensorTuple>(*op, {},
OpExprInterpContext(attrs, placement, nd_sbp));
});
m.add_functor( m.add_functor(
"DispatchImageBatchAlign", "DispatchImageBatchAlign",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, int32_t alignment, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, int32_t alignment,
const Shape& shape, const Symbol<DType>& data_type, bool dynamic_out) -> Maybe<Tensor> { const Shape& shape, const Symbol<DType>& data_type, bool dynamic_out) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs =
JUST(attrs.SetAttr("shape", shape)); THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "data_type", "alignment", "dynamic_out");
JUST(attrs.SetAttr("data_type", data_type->data_type())); attrs.SetAllAttrs(shape, data_type->data_type(), alignment, dynamic_out);
JUST(attrs.SetAttr("alignment", alignment));
JUST(attrs.SetAttr("dynamic_out", dynamic_out));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor("DispatchOfrecordBytesDecoder", m.add_functor("DispatchOfrecordBytesDecoder",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& name) -> Maybe<Tensor> { const std::string& name) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("name");
JUST(attrs.SetAttr("name", name)); attrs.SetAllAttrs(name);
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor( m.add_functor(
...@@ -335,15 +294,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -335,15 +294,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
const int64_t batch_size, const bool random_shuffle, const std::string& shuffle_mode, const int64_t batch_size, const bool random_shuffle, const std::string& shuffle_mode,
const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed, const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed,
const bool verify_example, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> { const bool verify_example, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr<std::vector<std::string>>("files", files)); "files", "batch_size", "random_shuffle", "shuffle_mode", "shuffle_buffer_size",
JUST(attrs.SetAttr<int64_t>("batch_size", batch_size)); "shuffle_after_epoch", "seed", "verify_example");
JUST(attrs.SetAttr<bool>("random_shuffle", random_shuffle)); attrs.SetAllAttrs(files, batch_size, random_shuffle, shuffle_mode, shuffle_buffer_size,
JUST(attrs.SetAttr<std::string>("shuffle_mode", shuffle_mode)); shuffle_after_epoch, random_seed, verify_example);
JUST(attrs.SetAttr<int32_t>("shuffle_buffer_size", shuffle_buffer_size));
JUST(attrs.SetAttr<bool>("shuffle_after_epoch", shuffle_after_epoch));
JUST(attrs.SetAttr<int64_t>("seed", random_seed));
JUST(attrs.SetAttr<bool>("verify_example", verify_example));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device))); return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
}); });
m.add_functor( m.add_functor(
...@@ -353,16 +308,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -353,16 +308,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed, const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed,
const bool verify_example, const Symbol<ParallelDesc>& placement, const bool verify_example, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> { const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr<std::vector<std::string>>("files", files)); "files", "batch_size", "random_shuffle", "shuffle_mode", "shuffle_buffer_size",
JUST(attrs.SetAttr<int64_t>("batch_size", batch_size)); "shuffle_after_epoch", "seed", "verify_example", "nd_sbp");
JUST(attrs.SetAttr<bool>("random_shuffle", random_shuffle)); attrs.SetAllAttrs(files, batch_size, random_shuffle, shuffle_mode, shuffle_buffer_size,
JUST(attrs.SetAttr<std::string>("shuffle_mode", shuffle_mode)); shuffle_after_epoch, random_seed, verify_example,
JUST(attrs.SetAttr<int32_t>("shuffle_buffer_size", shuffle_buffer_size)); *JUST(GetNdSbpStrList(sbp_tuple)));
JUST(attrs.SetAttr<bool>("shuffle_after_epoch", shuffle_after_epoch));
JUST(attrs.SetAttr<int64_t>("seed", random_seed));
JUST(attrs.SetAttr<bool>("verify_example", verify_example));
JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, return OpInterpUtil::Dispatch<Tensor>(*op, {},
OpExprInterpContext(attrs, placement, nd_sbp)); OpExprInterpContext(attrs, placement, nd_sbp));
...@@ -373,17 +324,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -373,17 +324,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol<DType>& dtype, int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol<DType>& dtype,
const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle, const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle,
int64_t random_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> { int64_t random_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("data_file_prefix", data_file_prefix)); "data_file_prefix", "seq_length", "label_length", "num_samples", "batch_size", "dtype",
JUST(attrs.SetAttr("seq_length", seq_length)); "split_sizes", "split_index", "shuffle", "random_seed");
JUST(attrs.SetAttr("label_length", label_length)); attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size,
JUST(attrs.SetAttr("num_samples", num_samples)); dtype->data_type(), split_sizes, split_index, shuffle, random_seed);
JUST(attrs.SetAttr("batch_size", batch_size));
JUST(attrs.SetAttr("dtype", dtype->data_type()));
JUST(attrs.SetAttr("split_sizes", split_sizes));
JUST(attrs.SetAttr("split_index", split_index));
JUST(attrs.SetAttr("shuffle", shuffle));
JUST(attrs.SetAttr("random_seed", random_seed));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device))); return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
}); });
m.add_functor( m.add_functor(
...@@ -393,17 +338,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -393,17 +338,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle, const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle,
int64_t random_seed, const Symbol<ParallelDesc>& placement, int64_t random_seed, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> { const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("data_file_prefix", data_file_prefix)); "data_file_prefix", "seq_length", "label_length", "num_samples", "batch_size", "dtype",
JUST(attrs.SetAttr("seq_length", seq_length)); "split_sizes", "split_index", "shuffle", "random_seed");
JUST(attrs.SetAttr("label_length", label_length)); attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size,
JUST(attrs.SetAttr("num_samples", num_samples)); dtype->data_type(), split_sizes, split_index, shuffle, random_seed);
JUST(attrs.SetAttr("batch_size", batch_size));
JUST(attrs.SetAttr("dtype", dtype->data_type()));
JUST(attrs.SetAttr("split_sizes", split_sizes));
JUST(attrs.SetAttr("split_index", split_index));
JUST(attrs.SetAttr("shuffle", shuffle));
JUST(attrs.SetAttr("random_seed", random_seed));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(*op, {}, return OpInterpUtil::Dispatch<Tensor>(*op, {},
OpExprInterpContext(attrs, placement, nd_sbp)); OpExprInterpContext(attrs, placement, nd_sbp));
...@@ -412,66 +351,50 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -412,66 +351,50 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, double scale, float l1, float l2, bool centered, float learning_rate, double scale, float l1, float l2, bool centered,
float epsilon, float decay_rate, float weight_decay) -> Maybe<void> { float epsilon, float decay_rate, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1",
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); "l2", "centered", "epsilon",
JUST(attrs.SetAttr("scale", scale)); "decay_rate", "weight_decay");
JUST(attrs.SetAttr("l1", l1)); attrs.SetAllAttrs(learning_rate, scale, l1, l2, centered, epsilon, decay_rate,
JUST(attrs.SetAttr("l2", l2)); weight_decay);
JUST(attrs.SetAttr("centered", centered));
JUST(attrs.SetAttr("epsilon", epsilon));
JUST(attrs.SetAttr("decay_rate", decay_rate));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchAdamUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, float bias_correction1, float bias_correction2,
double scale, float l1, float l2, float beta1, float beta2, float epsilon,
float weight_decay, bool amsgrad, bool do_bias_correction) -> Maybe<void> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("learning_rate_val", learning_rate));
JUST(attrs.SetAttr("bias_correction1_val", bias_correction1));
JUST(attrs.SetAttr("bias_correction2_val", bias_correction2));
JUST(attrs.SetAttr("scale", scale));
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("beta1", beta1));
JUST(attrs.SetAttr("beta2", beta2));
JUST(attrs.SetAttr("epsilon", epsilon));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(attrs.SetAttr("amsgrad", amsgrad));
JUST(attrs.SetAttr("do_bias_correction", do_bias_correction));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
m.add_functor(
"DispatchAdamUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
float bias_correction1, float bias_correction2, double scale, float l1, float l2,
float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,
bool do_bias_correction) -> Maybe<void> {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
"learning_rate_val", "bias_correction1_val", "bias_correction2_val", "scale", "l1",
"l2", "beta1", "beta2", "epsilon", "weight_decay", "amsgrad", "do_bias_correction");
attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1, l2, beta1,
beta2, epsilon, weight_decay, amsgrad, do_bias_correction);
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchAdagradUpdate", m.add_functor("DispatchAdagradUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, double scale, float l1, float l2, float lr_decay, float learning_rate, double scale, float l1, float l2, float lr_decay,
float weight_decay, float epsilon, int32_t train_step) -> Maybe<void> { float weight_decay, float epsilon, int32_t train_step) -> Maybe<void> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1",
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); "l2", "lr_decay", "weight_decay",
JUST(attrs.SetAttr("scale", scale)); "epsilon", "train_step_val");
JUST(attrs.SetAttr("l1", l1)); attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_decay, weight_decay, epsilon,
JUST(attrs.SetAttr("l2", l2)); train_step);
JUST(attrs.SetAttr("lr_decay", lr_decay));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(attrs.SetAttr("epsilon", epsilon));
JUST(attrs.SetAttr("train_step_val", train_step));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
m.add_functor( m.add_functor(
"DispatchMomentumUpdate", "DispatchMomentumUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate, [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
double scale, float l1, float l2, float beta, float weight_decay) -> Maybe<void> { double scale, float l1, float l2, float beta, float dampening, bool nesterov,
MutableAttrMap attrs; bool maximize, float weight_decay) -> Maybe<void> {
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); auto& attrs =
JUST(attrs.SetAttr("scale", scale)); THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2", "beta",
JUST(attrs.SetAttr("l1", l1)); "dampening", "nesterov", "maximize", "weight_decay");
JUST(attrs.SetAttr("l2", l2)); attrs.SetAllAttrs(learning_rate, scale, l1, l2, beta, dampening, nesterov, maximize,
JUST(attrs.SetAttr("beta", beta)); weight_decay);
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
...@@ -479,12 +402,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -479,12 +402,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
"DispatchSgdUpdate", "DispatchSgdUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate, [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
double scale, float l1, float l2, float weight_decay) -> Maybe<void> { double scale, float l1, float l2, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2",
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); "weight_decay");
JUST(attrs.SetAttr("scale", scale)); attrs.SetAllAttrs(learning_rate, scale, l1, l2, weight_decay);
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
...@@ -493,18 +413,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -493,18 +413,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
float learning_rate, float bias_correction1, float bias_correction2, float learning_rate, float bias_correction1, float bias_correction2,
double scale, float l1, float l2, float beta1, float beta2, float epsilon, double scale, float l1, float l2, float beta1, float beta2, float epsilon,
float weight_decay, bool do_bias_correction) -> Maybe<void> { float weight_decay, bool do_bias_correction) -> Maybe<void> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); "learning_rate_val", "bias_correction1_val", "bias_correction2_val", "scale",
JUST(attrs.SetAttr("bias_correction1_val", bias_correction1)); "l1", "l2", "beta1", "beta2", "epsilon", "weight_decay",
JUST(attrs.SetAttr("bias_correction2_val", bias_correction2)); "do_bias_correction");
JUST(attrs.SetAttr("scale", scale)); attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1,
JUST(attrs.SetAttr("l1", l1)); l2, beta1, beta2, epsilon, weight_decay, do_bias_correction);
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("beta1", beta1));
JUST(attrs.SetAttr("beta2", beta2));
JUST(attrs.SetAttr("epsilon", epsilon));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(attrs.SetAttr("do_bias_correction", do_bias_correction));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
...@@ -512,27 +426,61 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -512,27 +426,61 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, double scale, float l1, float l2, float lr_power, float learning_rate, double scale, float l1, float l2, float lr_power,
float lambda1, float lambda2, float beta, float weight_decay) -> Maybe<void> { float lambda1, float lambda2, float beta, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1",
JUST(attrs.SetAttr("learning_rate_val", learning_rate)); "l2", "lr_power", "lambda1",
JUST(attrs.SetAttr("scale", scale)); "lambda2", "beta", "weight_decay");
JUST(attrs.SetAttr("l1", l1)); attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_power, lambda1, lambda2, beta,
JUST(attrs.SetAttr("l2", l2)); weight_decay);
JUST(attrs.SetAttr("lr_power", lr_power));
JUST(attrs.SetAttr("lambda1", lambda1));
JUST(attrs.SetAttr("lambda2", lambda2));
JUST(attrs.SetAttr("beta", beta));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs)); JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
}); });
m.add_functor("DispatchEagerNcclAllReduce", m.add_functor(
"DispatchAdadeltaUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
double scale, float l1, float l2, float rho, float epsilon, bool maximize,
float weight_decay) -> Maybe<void> {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("learning_rate_val", "scale", "l1", "l2",
"rho", "epsilon", "maximize", "weight_decay");
attrs.SetAllAttrs(learning_rate, scale, l1, l2, rho, epsilon, maximize, weight_decay);
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchEagerCclAllReduce",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& parallel_conf, bool async_launch) -> Maybe<Tensor> { const std::string& parallel_conf) -> Maybe<Tensor> {
MutableAttrMap attrs; auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("parallel_conf");
JUST(attrs.SetAttr("parallel_conf", parallel_conf)); attrs.SetAllAttrs(parallel_conf);
JUST(attrs.SetAttr("async_launch", async_launch));
return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs); return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
}); });
m.add_functor(
"DispatchRawReader",
[](const std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,
const Shape& shape, const Symbol<DType>& data_type, const int64_t batch_size,
const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed,
const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("files", "shape", "data_type", "batch_size",
"random_shuffle", "shuffle_block_size", "seed",
"nd_sbp");
attrs.SetAllAttrs(files, shape, data_type->data_type(), batch_size, random_shuffle,
shuffle_block_size, random_seed, std::vector<std::string>());
return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
});
m.add_functor("DispatchRawReader",
[](const std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,
const Shape& shape, const Symbol<DType>& data_type, const int64_t batch_size,
const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
"files", "shape", "data_type", "batch_size", "random_shuffle",
"shuffle_block_size", "seed", "nd_sbp");
attrs.SetAllAttrs(files, shape, data_type->data_type(), batch_size,
random_shuffle, shuffle_block_size, random_seed,
*JUST(GetNdSbpStrList(sbp_tuple)));
auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(
*op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
});
} }
} // namespace impl } // namespace impl
......
...@@ -137,7 +137,7 @@ ...@@ -137,7 +137,7 @@
bind_python: True bind_python: True
- name: "dispatch_momentum_update" - name: "dispatch_momentum_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate"
bind_python: True bind_python: True
- name: "dispatch_sgd_update" - name: "dispatch_sgd_update"
...@@ -152,6 +152,18 @@ ...@@ -152,6 +152,18 @@
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate"
bind_python: True bind_python: True
- name: "dispatch_eager_nccl_all_reduce" - name: "dispatch_adadelta_update"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float rho, Float epsilon, Bool maximize, Float weight_decay=0) => DispatchAdadeltaUpdate"
bind_python: True
- name: "dispatch_eager_ccl_all_reduce"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf) => DispatchEagerCclAllReduce"
bind_python: True
- name: "dispatch_raw_reader"
signature: [
"Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Device device=None) => DispatchRawReader",
"Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Placement placement, SbpList sbp) => DispatchRawReader",
]
bind_python: True bind_python: True
...@@ -20,7 +20,6 @@ limitations under the License. ...@@ -20,7 +20,6 @@ limitations under the License.
#include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/common.h"
#include "oneflow/extension/python/numpy.h" #include "oneflow/extension/python/numpy.h"
#include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/framework/device.h" #include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional.h"
...@@ -68,7 +67,7 @@ DataType InferScalarType(PyObject* object) { ...@@ -68,7 +67,7 @@ DataType InferScalarType(PyObject* object) {
return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num).GetOrThrow(); return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num).GetOrThrow();
} else if (PySequence_Check(object)) { } else if (PySequence_Check(object)) {
int64_t length = PySequence_Length(object); int64_t length = PySequence_Length(object);
CHECK_GT_OR_THROW(length, 0) << "Index should not be empty."; if (length == 0) { return DataType::kInt64; }
DataType scalar_type = DataType::kInvalidDataType; DataType scalar_type = DataType::kInvalidDataType;
for (int64_t i = 0; i < length; ++i) { for (int64_t i = 0; i < length; ++i) {
PyObjectPtr item(PySequence_GetItem(object, i)); PyObjectPtr item(PySequence_GetItem(object, i));
...@@ -126,16 +125,18 @@ void RecursiveParseAndAssign(PyObject* object, char* data, const int& ndims, con ...@@ -126,16 +125,18 @@ void RecursiveParseAndAssign(PyObject* object, char* data, const int& ndims, con
} }
} }
void ParseArrayToBlob(PyObject* object, Blob* blob) { void ParseArrayToTensor(PyObject* object,
const DataType dtype = blob->data_type(); const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
const int ndims = blob->shape().NumAxes(); const DataType dtype = eager_blob_object->data_type();
const int ndims = eager_blob_object->shape().NumAxes();
DimVector strides(ndims); DimVector strides(ndims);
int64_t size = 1; int64_t size = 1;
for (int i = ndims - 1; i >= 0; --i) { for (int i = ndims - 1; i >= 0; --i) {
strides[i] = size; strides[i] = size;
size *= blob->shape().At(i); size *= eager_blob_object->shape().At(i);
} }
RecursiveParseAndAssign(object, blob->mut_dptr<char>(), ndims, 0, blob->shape(), strides, dtype); RecursiveParseAndAssign(object, eager_blob_object->mut_dptr<char>(), ndims, 0,
eager_blob_object->shape(), strides, dtype);
} }
Shape InferArraySizes(PyObject* object) { Shape InferArraySizes(PyObject* object) {
...@@ -144,7 +145,6 @@ Shape InferArraySizes(PyObject* object) { ...@@ -144,7 +145,6 @@ Shape InferArraySizes(PyObject* object) {
PyObjectPtr handle; PyObjectPtr handle;
while (PySequence_Check(seq)) { while (PySequence_Check(seq)) {
int64_t length = PySequence_Length(seq); int64_t length = PySequence_Length(seq);
CHECK_GT_OR_THROW(length, 0) << "Index should not be empty.";
sizes.emplace_back(length); sizes.emplace_back(length);
CHECK_LE_OR_THROW(sizes.size(), /*MAX_DIMS=*/128) CHECK_LE_OR_THROW(sizes.size(), /*MAX_DIMS=*/128)
<< "Too many dimensions " << Py_TYPE(seq)->tp_name; << "Too many dimensions " << Py_TYPE(seq)->tp_name;
...@@ -156,6 +156,8 @@ Shape InferArraySizes(PyObject* object) { ...@@ -156,6 +156,8 @@ Shape InferArraySizes(PyObject* object) {
} }
Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) { Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
// NOTE: convert data to indexing will ensure in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
const DataType dtype = InferScalarType(object); const DataType dtype = InferScalarType(object);
const auto& device = JUST(Device::New("cpu")); const auto& device = JUST(Device::New("cpu"));
...@@ -178,11 +180,11 @@ Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) { ...@@ -178,11 +180,11 @@ Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->AccessBlobByCallback( return builder->AccessBlobByCallback(
JUST(tensor->AsMirroredTensor()), JUST(tensor->AsLocalTensor()),
[handle](uint64_t ofblob_ptr) { [handle](ep::Stream* stream,
auto* of_blob = reinterpret_cast<OfBlob*>(ofblob_ptr); const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {
CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> { CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {
ParseArrayToBlob(handle.get(), of_blob->mut_blob()); ParseArrayToTensor(handle.get(), eager_blob_object);
return Maybe<void>::Ok(); return Maybe<void>::Ok();
})); }));
}, },
...@@ -211,7 +213,7 @@ IndexItem UnpackIndexItem(PyObject* object) { ...@@ -211,7 +213,7 @@ IndexItem UnpackIndexItem(PyObject* object) {
} else if (PySequence_Check(object)) { } else if (PySequence_Check(object)) {
return IndexItem(ConvertToIndexingTensor(object).GetPtrOrThrow()); return IndexItem(ConvertToIndexingTensor(object).GetPtrOrThrow());
} }
THROW(TypeError) << "Invalid index " << Py_TYPE(object)->tp_name; THROW(IndexError) << "Invalid index " << Py_TYPE(object)->tp_name;
return IndexItem(); return IndexItem();
} }
......
...@@ -37,6 +37,9 @@ namespace functional { ...@@ -37,6 +37,9 @@ namespace functional {
#define INSTANCE_OBJECT_AS_INTEGER(T) \ #define INSTANCE_OBJECT_AS_INTEGER(T) \
template<> \ template<> \
T PythonArg::ObjectAs<T>() const { \ T PythonArg::ObjectAs<T>() const { \
if (PyIntegerScalarTensorCheck(object_)) { \
return static_cast<T>(PyUnpackIntegerScalarTensor_AsLongLong(object_)); \
} \
return static_cast<T>(PyLong_AsLongLong(object_)); \ return static_cast<T>(PyLong_AsLongLong(object_)); \
} \ } \
template<> \ template<> \
...@@ -51,12 +54,15 @@ namespace functional { ...@@ -51,12 +54,15 @@ namespace functional {
return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>()); \ return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>()); \
} }
OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_INTEGER, INTEGER_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_INTEGER, INTEGER_AND_BOOL_TYPE_SEQ)
#undef INSTANCE_OBJECT_AS_INTEGER #undef INSTANCE_OBJECT_AS_INTEGER
#define INSTANCE_OBJECT_AS_FLOAT(T) \ #define INSTANCE_OBJECT_AS_FLOAT(T) \
template<> \ template<> \
T PythonArg::ObjectAs<T>() const { \ T PythonArg::ObjectAs<T>() const { \
if (PyFloatScalarTensorCheck(object_)) { \
return static_cast<T>(PyUnpackFloatScalarTensor_AsDouble(object_)); \
} \
return static_cast<T>(PyFloat_AsDouble(object_)); \ return static_cast<T>(PyFloat_AsDouble(object_)); \
} \ } \
template<> \ template<> \
...@@ -88,6 +94,7 @@ INSTANCE_OBJECT_AS_SHARED_PTR(std::string) ...@@ -88,6 +94,7 @@ INSTANCE_OBJECT_AS_SHARED_PTR(std::string)
template<> template<>
Scalar PythonArg::ObjectAs<Scalar>() const { Scalar PythonArg::ObjectAs<Scalar>() const {
if (PyScalarTensorCheck(object_)) { return PyUnpackScalarTensor(object_); }
return PyUnpackScalar(object_); return PyUnpackScalar(object_);
} }
INSTANCE_OBJECT_AS_SHARED_PTR(Scalar) INSTANCE_OBJECT_AS_SHARED_PTR(Scalar)
...@@ -120,8 +127,7 @@ INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<DType>>) ...@@ -120,8 +127,7 @@ INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<DType>>)
template<> template<>
Shape PythonArg::ObjectAs<Shape>() const { Shape PythonArg::ObjectAs<Shape>() const {
const auto& shape = PyUnpackLongSequence<int64_t>(object_); return PyUnpackShape(object_);
return Shape(DimVector(shape.begin(), shape.end()));
} }
INSTANCE_OBJECT_AS_SHARED_PTR(Shape) INSTANCE_OBJECT_AS_SHARED_PTR(Shape)
...@@ -197,7 +203,9 @@ bool PythonArg::TypeCheck(ValueType type) const { ...@@ -197,7 +203,9 @@ bool PythonArg::TypeCheck(ValueType type) const {
case kUINT32: case kUINT32:
case kINT64: case kINT64:
case kUINT64: case kUINT64:
case kBOOL: return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_); case kBOOL:
return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_)
|| PyIntegerScalarTensorCheck(object_) || PyBoolScalarTensorCheck(object_);
case kINT32_LIST: case kINT32_LIST:
case kUINT32_LIST: case kUINT32_LIST:
case kINT64_LIST: case kINT64_LIST:
...@@ -206,16 +214,17 @@ bool PythonArg::TypeCheck(ValueType type) const { ...@@ -206,16 +214,17 @@ bool PythonArg::TypeCheck(ValueType type) const {
case kFLOAT: case kFLOAT:
case kDOUBLE: case kDOUBLE:
return PyFloat_Check(object_) || PyLong_Check(object_) return PyFloat_Check(object_) || PyLong_Check(object_)
|| numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_); || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_)
|| PyFloatScalarTensorCheck(object_) || PyIntegerScalarTensorCheck(object_);
case kFLOAT_LIST: case kFLOAT_LIST:
case kDOUBLE_LIST: case kDOUBLE_LIST:
return PyFloatSquenceCheck(object_) return PyFloatSequenceCheck(object_)
|| (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_))); || (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_)));
case kSTRING: return PyStringCheck(object_); case kSTRING: return PyStringCheck(object_);
case kSTRING_LIST: return PyStringSequenceCheck(object_); case kSTRING_LIST: return PyStringSequenceCheck(object_);
case kSCALAR: case kSCALAR:
return PyScalarCheck(object_) || numpy::PyArrayCheckLongScalar(object_) return PyScalarCheck(object_) || numpy::PyArrayCheckLongScalar(object_)
|| numpy::PyArrayCheckFloatScalar(object_); || numpy::PyArrayCheckFloatScalar(object_) || PyScalarTensorCheck(object_);
case kTENSOR: case kTENSOR:
case kTENSOR_REF: return PyTensor_Check(object_); case kTENSOR_REF: return PyTensor_Check(object_);
case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_); case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_);
...@@ -224,7 +233,7 @@ bool PythonArg::TypeCheck(ValueType type) const { ...@@ -224,7 +233,7 @@ bool PythonArg::TypeCheck(ValueType type) const {
case kGENERATOR: case kGENERATOR:
case kGENERATOR_REF: return PyGeneratorCheck(object_); case kGENERATOR_REF: return PyGeneratorCheck(object_);
case kTENSOR_INDEX: return PyTensorIndexCheck(object_); case kTENSOR_INDEX: return PyTensorIndexCheck(object_);
case kDEVICE: return PyDeviceCheck(object_) || PyStringCheck(object_); case kDEVICE: return PyStringCheck(object_) || PyDeviceCheck(object_);
case kPARALLEL_DESC: return PyParallelDescCheck(object_); case kPARALLEL_DESC: return PyParallelDescCheck(object_);
case kSBP_PARALLEL: return PySbpParallelCheck(object_); case kSBP_PARALLEL: return PySbpParallelCheck(object_);
case kSBP_PARALLEL_LIST: case kSBP_PARALLEL_LIST:
...@@ -240,8 +249,6 @@ bool PythonArg::TypeCheck(ValueType type) const { ...@@ -240,8 +249,6 @@ bool PythonArg::TypeCheck(ValueType type) const {
return false; return false;
} }
bool PythonArgCheck(const PythonArg& arg, ValueType type) { return arg.TypeCheck(type); }
} // namespace functional } // namespace functional
} // namespace one } // namespace one
} // namespace oneflow } // namespace oneflow
...@@ -61,22 +61,13 @@ struct optional_traits<Optional<T>> { ...@@ -61,22 +61,13 @@ struct optional_traits<Optional<T>> {
class PythonArg { class PythonArg {
public: public:
PythonArg() = default; PythonArg() = default;
PythonArg(const py::object& object, int size = 0) : PythonArg(object.ptr(), size) {}
PythonArg(PyObject* object, int size = 0) PythonArg(PyObject* object, int size = 0)
: object_(object), default_val_(), size_(size), tag_(HAS_OBJECT) {} : object_(object), default_val_(), size_(size), tag_(HAS_OBJECT) {}
PythonArg(const std::shared_ptr<const detail::DefaultVal>& value, int size = 0) PythonArg(const detail::DefaultVal* value, int size = 0)
: object_(nullptr), default_val_(value), size_(size), tag_(HAS_DEFAULT) {} : object_(nullptr), default_val_(value), size_(size), tag_(HAS_DEFAULT) {}
template<typename T, typename std::enable_if<!py::detail::is_pyobject<T>::value, int>::type = 0>
PythonArg(const T& value, int size = 0)
: object_(nullptr),
default_val_(std::make_shared<detail::TypedDefaultVal<T>>(value)),
size_(size),
tag_(HAS_DEFAULT) {}
virtual ~PythonArg() = default;
template<typename T, typename std::enable_if<!internal::IsOptional<T>::value, int>::type = 0> template<typename T, typename std::enable_if<!internal::IsOptional<T>::value, int>::type = 0>
T As() const { T As() const {
if (tag_ == HAS_DEFAULT) { if (tag_ == HAS_DEFAULT) {
...@@ -109,13 +100,11 @@ class PythonArg { ...@@ -109,13 +100,11 @@ class PythonArg {
T ObjectAs() const; T ObjectAs() const;
PyObject* object_; PyObject* object_;
std::shared_ptr<const detail::DefaultVal> default_val_; const detail::DefaultVal* default_val_;
size_t size_; size_t size_;
enum { HAS_OBJECT, HAS_DEFAULT, HAS_NONE } tag_; enum { HAS_OBJECT, HAS_DEFAULT, HAS_NONE } tag_;
}; };
bool PythonArgCheck(const PythonArg& arg, ValueType type);
} // namespace functional } // namespace functional
} // namespace one } // namespace one
} // namespace oneflow } // namespace oneflow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment