Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
05cb58ad
Commit
05cb58ad
authored
Jul 20, 2016
by
Ivan Smirnov
Browse files
Cleanup: move numpy API bindings out of py::array
parent
afb07e7e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
85 deletions
+87
-85
include/pybind11/numpy.h
include/pybind11/numpy.h
+87
-85
No files found.
include/pybind11/numpy.h
View file @
05cb58ad
...
...
@@ -28,87 +28,94 @@ NAMESPACE_BEGIN(pybind11)
namespace
detail
{
template
<
typename
type
,
typename
SFINAE
=
void
>
struct
npy_format_descriptor
{
};
template
<
typename
type
>
struct
is_pod_struct
;
}
class
array
:
public
buffer
{
public:
struct
API
{
enum
Entries
{
API_PyArray_Type
=
2
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_PyArray_DescrNewFromType
=
9
,
API_PyArray_DescrConverter
=
174
,
API_PyArray_EquivTypes
=
182
,
API_PyArray_GetArrayParamsFromObject
=
278
,
NPY_C_CONTIGUOUS_
=
0x0001
,
NPY_F_CONTIGUOUS_
=
0x0002
,
NPY_ARRAY_FORCECAST_
=
0x0010
,
NPY_ENSURE_ARRAY_
=
0x0040
,
NPY_BOOL_
=
0
,
NPY_BYTE_
,
NPY_UBYTE_
,
NPY_SHORT_
,
NPY_USHORT_
,
NPY_INT_
,
NPY_UINT_
,
NPY_LONG_
,
NPY_ULONG_
,
NPY_LONGLONG_
,
NPY_ULONGLONG_
,
NPY_FLOAT_
,
NPY_DOUBLE_
,
NPY_LONGDOUBLE_
,
NPY_CFLOAT_
,
NPY_CDOUBLE_
,
NPY_CLONGDOUBLE_
,
NPY_OBJECT_
=
17
,
NPY_STRING_
,
NPY_UNICODE_
,
NPY_VOID_
};
static
API
lookup
()
{
module
m
=
module
::
import
(
"numpy.core.multiarray"
);
object
c
=
(
object
)
m
.
attr
(
"_ARRAY_API"
);
struct
npy_api
{
enum
constants
{
NPY_C_CONTIGUOUS_
=
0x0001
,
NPY_F_CONTIGUOUS_
=
0x0002
,
NPY_ARRAY_FORCECAST_
=
0x0010
,
NPY_ENSURE_ARRAY_
=
0x0040
,
NPY_BOOL_
=
0
,
NPY_BYTE_
,
NPY_UBYTE_
,
NPY_SHORT_
,
NPY_USHORT_
,
NPY_INT_
,
NPY_UINT_
,
NPY_LONG_
,
NPY_ULONG_
,
NPY_LONGLONG_
,
NPY_ULONGLONG_
,
NPY_FLOAT_
,
NPY_DOUBLE_
,
NPY_LONGDOUBLE_
,
NPY_CFLOAT_
,
NPY_CDOUBLE_
,
NPY_CLONGDOUBLE_
,
NPY_OBJECT_
=
17
,
NPY_STRING_
,
NPY_UNICODE_
,
NPY_VOID_
};
static
npy_api
&
get
()
{
static
npy_api
api
=
lookup
();
return
api
;
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
PyObject
*
(
*
PyArray_DescrFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewFromDescr_
)
(
PyTypeObject
*
,
PyObject
*
,
int
,
Py_intptr_t
*
,
Py_intptr_t
*
,
void
*
,
int
,
PyObject
*
);
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
int
(
*
PyArray_GetArrayParamsFromObject_
)(
PyObject
*
,
PyObject
*
,
char
,
PyObject
**
,
int
*
,
Py_ssize_t
*
,
PyObject
**
,
PyObject
*
);
private:
enum
functions
{
API_PyArray_Type
=
2
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_PyArray_DescrNewFromType
=
9
,
API_PyArray_DescrConverter
=
174
,
API_PyArray_EquivTypes
=
182
,
API_PyArray_GetArrayParamsFromObject
=
278
,
};
static
npy_api
lookup
()
{
module
m
=
module
::
import
(
"numpy.core.multiarray"
);
object
c
=
(
object
)
m
.
attr
(
"_ARRAY_API"
);
#if PY_MAJOR_VERSION >= 3
void
**
api_ptr
=
(
void
**
)
(
c
?
PyCapsule_GetPointer
(
c
.
ptr
(),
NULL
)
:
nullptr
);
void
**
api_ptr
=
(
void
**
)
(
c
?
PyCapsule_GetPointer
(
c
.
ptr
(),
NULL
)
:
nullptr
);
#else
void
**
api_ptr
=
(
void
**
)
(
c
?
PyCObject_AsVoidPtr
(
c
.
ptr
())
:
nullptr
);
void
**
api_ptr
=
(
void
**
)
(
c
?
PyCObject_AsVoidPtr
(
c
.
ptr
())
:
nullptr
);
#endif
API
api
;
npy_api
api
;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
DECL_NPY_API
(
PyArray_NewFromDescr
);
DECL_NPY_API
(
PyArray_DescrNewFromType
);
DECL_NPY_API
(
PyArray_DescrConverter
);
DECL_NPY_API
(
PyArray_EquivTypes
);
DECL_NPY_API
(
PyArray_GetArrayParamsFromObject
);
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
DECL_NPY_API
(
PyArray_NewFromDescr
);
DECL_NPY_API
(
PyArray_DescrNewFromType
);
DECL_NPY_API
(
PyArray_DescrConverter
);
DECL_NPY_API
(
PyArray_EquivTypes
);
DECL_NPY_API
(
PyArray_GetArrayParamsFromObject
);
#undef DECL_NPY_API
return
api
;
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
PyObject
*
(
*
PyArray_DescrFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewFromDescr_
)
(
PyTypeObject
*
,
PyObject
*
,
int
,
Py_intptr_t
*
,
Py_intptr_t
*
,
void
*
,
int
,
PyObject
*
);
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
int
(
*
PyArray_GetArrayParamsFromObject_
)(
PyObject
*
,
PyObject
*
,
char
,
PyObject
**
,
int
*
,
Py_ssize_t
*
,
PyObject
**
,
PyObject
*
);
};
return
api
;
}
};
}
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
lookup_api
().
PyArray_Check_
)
class
array
:
public
buffer
{
public:
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
detail
::
npy_api
::
get
().
PyArray_Check_
)
enum
{
c_style
=
API
::
NPY_C_CONTIGUOUS_
,
f_style
=
API
::
NPY_F_CONTIGUOUS_
,
forcecast
=
API
::
NPY_ARRAY_FORCECAST_
c_style
=
detail
::
npy_api
::
NPY_C_CONTIGUOUS_
,
f_style
=
detail
::
npy_api
::
NPY_F_CONTIGUOUS_
,
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
API
&
api
=
lookup_api
();
auto
&
api
=
detail
::
npy_api
::
get
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
Type
>::
dtype
().
release
().
ptr
();
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
object
tmp
=
object
(
api
.
PyArray_NewFromDescr_
(
...
...
@@ -121,7 +128,7 @@ public:
}
array
(
const
buffer_info
&
info
)
{
auto
&
api
=
lookup_api
();
auto
&
api
=
detail
::
npy_api
::
get
();
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
auto
numpy_internal
=
module
::
import
(
"numpy.core._internal"
);
...
...
@@ -139,11 +146,6 @@ public:
}
protected:
static
API
&
lookup_api
()
{
static
API
api
=
API
::
lookup
();
return
api
;
}
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
static
object
strip_padding_fields
(
object
dtype
)
{
...
...
@@ -183,7 +185,7 @@ protected:
args
[
"itemsize"
]
=
dtype
.
attr
(
"itemsize"
).
cast
<
int_
>
();
PyObject
*
descr
=
nullptr
;
if
(
!
lookup_api
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
descr
)
||
!
descr
)
if
(
!
detail
::
npy_api
::
get
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
descr
)
||
!
descr
)
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
return
object
(
descr
,
false
);
}
...
...
@@ -198,10 +200,10 @@ public:
static
PyObject
*
ensure
(
PyObject
*
ptr
)
{
if
(
ptr
==
nullptr
)
return
nullptr
;
API
&
api
=
lookup_api
();
auto
&
api
=
detail
::
npy_api
::
get
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
T
>::
dtype
().
release
().
ptr
();
PyObject
*
result
=
api
.
PyArray_FromAny_
(
ptr
,
descr
,
0
,
0
,
API
::
NPY_ENSURE_ARRAY_
|
ExtraFlags
,
nullptr
);
detail
::
npy_api
::
NPY_ENSURE_ARRAY_
|
ExtraFlags
,
nullptr
);
if
(
!
result
)
PyErr_Clear
();
Py_DECREF
(
ptr
);
...
...
@@ -246,12 +248,12 @@ struct is_pod_struct {
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
private:
constexpr
static
const
int
values
[
8
]
=
{
array
::
API
::
NPY_BYTE_
,
array
::
API
::
NPY_UBYTE_
,
array
::
API
::
NPY_SHORT_
,
array
::
API
::
NPY_USHORT_
,
array
::
API
::
NPY_INT_
,
array
::
API
::
NPY_UINT_
,
array
::
API
::
NPY_LONGLONG_
,
array
::
API
::
NPY_ULONGLONG_
};
npy_api
::
NPY_BYTE_
,
npy_api
::
NPY_UBYTE_
,
npy_api
::
NPY_SHORT_
,
npy_api
::
NPY_USHORT_
,
npy_api
::
NPY_INT_
,
npy_api
::
NPY_UINT_
,
npy_api
::
NPY_LONGLONG_
,
npy_api
::
NPY_ULONGLONG_
};
public:
enum
{
value
=
values
[
detail
::
log2
(
sizeof
(
T
))
*
2
+
(
std
::
is_unsigned
<
T
>::
value
?
1
:
0
)]
};
static
object
dtype
()
{
if
(
auto
ptr
=
array
::
lookup_api
().
PyArray_DescrFromType_
(
value
))
if
(
auto
ptr
=
npy_api
::
get
().
PyArray_DescrFromType_
(
value
))
return
object
(
ptr
,
true
);
pybind11_fail
(
"Unsupported buffer format!"
);
}
...
...
@@ -264,9 +266,9 @@ template <typename T> constexpr const int npy_format_descriptor<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>::
values
[
8
];
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value =
array::API
::NumPyName }; \
enum { value =
npy_api
::NumPyName }; \
static object dtype() { \
if (auto ptr =
array::lookup_api
().PyArray_DescrFromType_(value)) \
if (auto ptr =
npy_api::get
().PyArray_DescrFromType_(value)) \
return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
} \
...
...
@@ -281,7 +283,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
auto& api =
array::lookup_api
(); \
auto& api =
npy_api::get
();
\
PyObject *descr = nullptr; \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
...
...
@@ -319,7 +321,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
auto
&
api
=
array
::
lookup_api
();
auto
&
api
=
npy_api
::
get
();
auto
args
=
dict
();
list
names
{
},
offsets
{
},
formats
{
};
for
(
auto
field
:
fields
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment