Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
dd9bd777
Commit
dd9bd777
authored
Oct 25, 2016
by
Wenzel Jakob
Committed by
GitHub
Oct 25, 2016
Browse files
Merge pull request #453 from aldanor/feature/numpy-scalars
NumPy scalars to ctypes conversion support
parents
6ba98650
8f3e045d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
96 additions
and
18 deletions
+96
-18
include/pybind11/cast.h
include/pybind11/cast.h
+8
-2
include/pybind11/common.h
include/pybind11/common.h
+1
-0
include/pybind11/numpy.h
include/pybind11/numpy.h
+36
-0
include/pybind11/pybind11.h
include/pybind11/pybind11.h
+10
-16
tests/test_numpy_dtypes.cpp
tests/test_numpy_dtypes.cpp
+10
-0
tests/test_numpy_dtypes.py
tests/test_numpy_dtypes.py
+31
-0
No files found.
include/pybind11/cast.h
View file @
dd9bd777
...
...
@@ -26,6 +26,7 @@ struct type_info {
void
(
*
init_holder
)(
PyObject
*
,
const
void
*
);
std
::
vector
<
PyObject
*
(
*
)(
PyObject
*
,
PyTypeObject
*
)
>
implicit_conversions
;
std
::
vector
<
std
::
pair
<
const
std
::
type_info
*
,
void
*
(
*
)(
void
*
)
>>
implicit_casts
;
std
::
vector
<
bool
(
*
)(
PyObject
*
,
void
*&
)
>
*
direct_conversions
;
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
)
=
nullptr
;
void
*
get_buffer_data
=
nullptr
;
/** A simple type never occurs as a (direct or indirect) parent
...
...
@@ -90,7 +91,8 @@ PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
}
while
(
true
);
}
PYBIND11_NOINLINE
inline
detail
::
type_info
*
get_type_info
(
const
std
::
type_info
&
tp
,
bool
throw_if_missing
)
{
PYBIND11_NOINLINE
inline
detail
::
type_info
*
get_type_info
(
const
std
::
type_info
&
tp
,
bool
throw_if_missing
=
false
)
{
auto
&
types
=
get_internals
().
registered_types_cpp
;
auto
it
=
types
.
find
(
std
::
type_index
(
tp
));
...
...
@@ -157,7 +159,7 @@ inline void keep_alive_impl(handle nurse, handle patient);
class
type_caster_generic
{
public:
PYBIND11_NOINLINE
type_caster_generic
(
const
std
::
type_info
&
type_info
)
:
typeinfo
(
get_type_info
(
type_info
,
false
))
{
}
:
typeinfo
(
get_type_info
(
type_info
))
{
}
PYBIND11_NOINLINE
bool
load
(
handle
src
,
bool
convert
)
{
if
(
!
src
)
...
...
@@ -215,6 +217,10 @@ public:
if
(
load
(
temp
,
false
))
return
true
;
}
for
(
auto
&
converter
:
*
typeinfo
->
direct_conversions
)
{
if
(
converter
(
src
.
ptr
(),
value
))
return
true
;
}
}
return
false
;
}
...
...
include/pybind11/common.h
View file @
dd9bd777
...
...
@@ -321,6 +321,7 @@ struct internals {
std
::
unordered_map
<
const
void
*
,
void
*>
registered_types_py
;
// PyTypeObject* -> type_info
std
::
unordered_multimap
<
const
void
*
,
void
*>
registered_instances
;
// void * -> PyObject*
std
::
unordered_set
<
std
::
pair
<
const
PyObject
*
,
const
char
*>
,
overload_hash
>
inactive_overload_cache
;
std
::
unordered_map
<
std
::
type_index
,
std
::
vector
<
bool
(
*
)(
PyObject
*
,
void
*&
)
>>
direct_conversions
;
std
::
forward_list
<
void
(
*
)
(
std
::
exception_ptr
)
>
registered_exception_translators
;
#if defined(WITH_THREAD)
decltype
(
PyThread_create_key
())
tstate
=
0
;
// Usually an int but a long on Cygwin64 with Python 3.x
...
...
include/pybind11/numpy.h
View file @
dd9bd777
...
...
@@ -63,6 +63,14 @@ struct PyArray_Proxy {
int
flags
;
};
struct
PyVoidScalarObject_Proxy
{
PyObject_VAR_HEAD
char
*
obval
;
PyArrayDescr_Proxy
*
descr
;
int
flags
;
PyObject
*
base
;
};
struct
npy_api
{
enum
constants
{
NPY_C_CONTIGUOUS_
=
0x0001
,
...
...
@@ -103,7 +111,9 @@ struct npy_api {
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyTypeObject
*
PyVoidArrType_Type_
;
PyTypeObject
*
PyArrayDescr_Type_
;
PyObject
*
(
*
PyArray_DescrFromScalar_
)(
PyObject
*
);
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
...
...
@@ -114,7 +124,9 @@ private:
enum
functions
{
API_PyArray_Type
=
2
,
API_PyArrayDescr_Type
=
3
,
API_PyVoidArrType_Type
=
39
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_DescrFromScalar
=
57
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
...
...
@@ -136,8 +148,10 @@ private:
npy_api
api
;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyVoidArrType_Type
);
DECL_NPY_API
(
PyArrayDescr_Type
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_DescrFromScalar
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
DECL_NPY_API
(
PyArray_NewFromDescr
);
...
...
@@ -658,6 +672,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
if
(
dtype_ptr
)
pybind11_fail
(
"NumPy: dtype is already registered"
);
list
names
,
formats
,
offsets
;
for
(
auto
field
:
fields
)
{
if
(
!
field
.
descr
)
...
...
@@ -700,11 +717,30 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
(),
1
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
arr
.
dtype
().
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
register_direct_converter
();
}
private:
static
std
::
string
format_str
;
static
PyObject
*
dtype_ptr
;
static
bool
direct_converter
(
PyObject
*
obj
,
void
*&
value
)
{
auto
&
api
=
npy_api
::
get
();
if
(
!
PyObject_TypeCheck
(
obj
,
api
.
PyVoidArrType_Type_
))
return
false
;
if
(
auto
descr
=
object
(
api
.
PyArray_DescrFromScalar_
(
obj
),
false
))
{
if
(
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
descr
.
ptr
()))
{
value
=
((
PyVoidScalarObject_Proxy
*
)
obj
)
->
obval
;
return
true
;
}
}
return
false
;
}
static
void
register_direct_converter
()
{
get_internals
().
direct_conversions
[
std
::
type_index
(
typeid
(
T
))].
push_back
(
direct_converter
);
}
};
template
<
typename
T
>
...
...
include/pybind11/pybind11.h
View file @
dd9bd777
...
...
@@ -180,8 +180,6 @@ protected:
a
.
descr
=
strdup
(
a
.
value
.
attr
(
"__repr__"
)().
cast
<
std
::
string
>
().
c_str
());
}
auto
const
&
registered_types
=
detail
::
get_internals
().
registered_types_cpp
;
/* Generate a proper function signature */
std
::
string
signature
;
size_t
type_depth
=
0
,
char_index
=
0
,
type_index
=
0
,
arg_index
=
0
;
...
...
@@ -216,9 +214,8 @@ protected:
const
std
::
type_info
*
t
=
types
[
type_index
++
];
if
(
!
t
)
pybind11_fail
(
"Internal error while parsing type signature (1)"
);
auto
it
=
registered_types
.
find
(
std
::
type_index
(
*
t
));
if
(
it
!=
registered_types
.
end
())
{
signature
+=
((
const
detail
::
type_info
*
)
it
->
second
)
->
type
->
tp_name
;
if
(
auto
tinfo
=
detail
::
get_type_info
(
*
t
))
{
signature
+=
tinfo
->
type
->
tp_name
;
}
else
{
std
::
string
tname
(
t
->
name
());
detail
::
clean_type_id
(
tname
);
...
...
@@ -610,8 +607,7 @@ protected:
auto
&
internals
=
get_internals
();
auto
tindex
=
std
::
type_index
(
*
(
rec
->
type
));
if
(
internals
.
registered_types_cpp
.
find
(
tindex
)
!=
internals
.
registered_types_cpp
.
end
())
if
(
get_type_info
(
*
(
rec
->
type
)))
pybind11_fail
(
"generic_type: type
\"
"
+
std
::
string
(
rec
->
name
)
+
"
\"
is already registered!"
);
...
...
@@ -672,6 +668,7 @@ protected:
tinfo
->
type
=
(
PyTypeObject
*
)
type
;
tinfo
->
type_size
=
rec
->
type_size
;
tinfo
->
init_holder
=
rec
->
init_holder
;
tinfo
->
direct_conversions
=
&
internals
.
direct_conversions
[
tindex
];
internals
.
registered_types_cpp
[
tindex
]
=
tinfo
;
internals
.
registered_types_py
[
type
]
=
tinfo
;
...
...
@@ -1333,11 +1330,11 @@ template <typename InputType, typename OutputType> void implicitly_convertible()
PyErr_Clear
();
return
result
;
};
auto
&
registered_types
=
detail
::
get_internals
().
registered_types_cpp
;
auto
it
=
registered_types
.
find
(
std
::
type_index
(
typeid
(
OutputType
)));
if
(
it
==
registered_types
.
end
())
if
(
auto
tinfo
=
detail
::
get_type_info
(
typeid
(
OutputType
)))
tinfo
->
implicit_conversions
.
push_back
(
implicit_caster
);
else
pybind11_fail
(
"implicitly_convertible: Unable to find type "
+
type_id
<
OutputType
>
());
((
detail
::
type_info
*
)
it
->
second
)
->
implicit_conversions
.
push_back
(
implicit_caster
);
}
template
<
typename
ExceptionTranslator
>
...
...
@@ -1589,11 +1586,8 @@ inline function get_type_overload(const void *this_ptr, const detail::type_info
}
template
<
class
T
>
function
get_overload
(
const
T
*
this_ptr
,
const
char
*
name
)
{
auto
&
cpp_types
=
detail
::
get_internals
().
registered_types_cpp
;
auto
it
=
cpp_types
.
find
(
typeid
(
T
));
if
(
it
==
cpp_types
.
end
())
return
function
();
return
get_type_overload
(
this_ptr
,
(
const
detail
::
type_info
*
)
it
->
second
,
name
);
auto
tinfo
=
detail
::
get_type_info
(
typeid
(
T
));
return
tinfo
?
get_type_overload
(
this_ptr
,
tinfo
,
name
)
:
function
();
}
#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
...
...
tests/test_numpy_dtypes.cpp
View file @
dd9bd777
...
...
@@ -298,6 +298,9 @@ test_initializer numpy_dtypes([](py::module &m) {
return
;
}
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
py
::
class_
<
SimpleStruct
>
(
m
,
"SimpleStruct"
);
PYBIND11_NUMPY_DTYPE
(
SimpleStruct
,
x
,
y
,
z
);
PYBIND11_NUMPY_DTYPE
(
PackedStruct
,
x
,
y
,
z
);
PYBIND11_NUMPY_DTYPE
(
NestedStruct
,
a
,
b
);
...
...
@@ -306,6 +309,9 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE
(
StringStruct
,
a
,
b
);
PYBIND11_NUMPY_DTYPE
(
EnumStruct
,
e1
,
e2
);
// ... or after
py
::
class_
<
PackedStruct
>
(
m
,
"PackedStruct"
);
m
.
def
(
"create_rec_simple"
,
&
create_recarray
<
SimpleStruct
>
);
m
.
def
(
"create_rec_packed"
,
&
create_recarray
<
PackedStruct
>
);
m
.
def
(
"create_rec_nested"
,
&
create_nested
);
...
...
@@ -324,6 +330,10 @@ test_initializer numpy_dtypes([](py::module &m) {
m
.
def
(
"test_array_ctors"
,
&
test_array_ctors
);
m
.
def
(
"test_dtype_ctors"
,
&
test_dtype_ctors
);
m
.
def
(
"test_dtype_methods"
,
&
test_dtype_methods
);
m
.
def
(
"f_simple"
,
[](
SimpleStruct
s
)
{
return
s
.
y
*
10
;
});
m
.
def
(
"f_packed"
,
[](
PackedStruct
s
)
{
return
s
.
y
*
10
;
});
m
.
def
(
"f_nested"
,
[](
NestedStruct
s
)
{
return
s
.
a
.
y
*
10
;
});
m
.
def
(
"register_dtype"
,
[]()
{
PYBIND11_NUMPY_DTYPE
(
SimpleStruct
,
x
,
y
,
z
);
});
});
#undef PYBIND11_PACKED
tests/test_numpy_dtypes.py
View file @
dd9bd777
...
...
@@ -174,3 +174,34 @@ def test_signature(doc):
from
pybind11_tests
import
create_rec_nested
assert
doc
(
create_rec_nested
)
==
"create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
@
pytest
.
requires_numpy
def
test_scalar_conversion
():
from
pybind11_tests
import
(
create_rec_simple
,
f_simple
,
create_rec_packed
,
f_packed
,
create_rec_nested
,
f_nested
,
create_enum_array
)
n
=
3
arrays
=
[
create_rec_simple
(
n
),
create_rec_packed
(
n
),
create_rec_nested
(
n
),
create_enum_array
(
n
)]
funcs
=
[
f_simple
,
f_packed
,
f_nested
]
for
i
,
func
in
enumerate
(
funcs
):
for
j
,
arr
in
enumerate
(
arrays
):
if
i
==
j
and
i
<
2
:
assert
[
func
(
arr
[
k
])
for
k
in
range
(
n
)]
==
[
k
*
10
for
k
in
range
(
n
)]
else
:
with
pytest
.
raises
(
TypeError
)
as
excinfo
:
func
(
arr
[
0
])
assert
'incompatible function arguments'
in
str
(
excinfo
.
value
)
@
pytest
.
requires_numpy
def
test_register_dtype
():
from
pybind11_tests
import
register_dtype
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
register_dtype
()
assert
'dtype is already registered'
in
str
(
excinfo
.
value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment