Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
01f74095
Commit
01f74095
authored
Jul 23, 2016
by
Ivan Smirnov
Browse files
Initial implementation of py::dtype
parent
05cb58ad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
110 additions
and
83 deletions
+110
-83
example/example-numpy-dtypes.cpp
example/example-numpy-dtypes.cpp
+6
-9
include/pybind11/numpy.h
include/pybind11/numpy.h
+104
-74
No files found.
example/example-numpy-dtypes.cpp
View file @
01f74095
...
...
@@ -158,15 +158,12 @@ void print_format_descriptors() {
}
void
print_dtypes
()
{
auto
to_str
=
[](
py
::
object
obj
)
{
return
(
std
::
string
)
(
py
::
str
)
((
py
::
object
)
obj
.
attr
(
"__str__"
))();
};
std
::
cout
<<
to_str
(
py
::
dtype_of
<
SimpleStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PackedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
NestedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PartialStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
PartialNestedStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
to_str
(
py
::
dtype_of
<
StringStruct
>
())
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
SimpleStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PackedStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
NestedStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PartialStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
PartialNestedStruct
>
().
str
()
<<
std
::
endl
;
std
::
cout
<<
(
std
::
string
)
py
::
dtype
::
of
<
StringStruct
>
().
str
()
<<
std
::
endl
;
}
void
init_ex_numpy_dtypes
(
py
::
module
&
m
)
{
...
...
include/pybind11/numpy.h
View file @
01f74095
...
...
@@ -52,7 +52,12 @@ struct npy_api {
return
api
;
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
bool
PyArray_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArray_Type_
);
}
bool
PyArrayDescr_Check_
(
PyObject
*
obj
)
const
{
return
(
bool
)
PyObject_TypeCheck
(
obj
,
PyArrayDescr_Type_
);
}
PyObject
*
(
*
PyArray_DescrFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewFromDescr_
)
...
...
@@ -61,6 +66,7 @@ struct npy_api {
PyObject
*
(
*
PyArray_DescrNewFromType_
)(
int
);
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyTypeObject
*
PyArrayDescr_Type_
;
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
bool
(
*
PyArray_EquivTypes_
)
(
PyObject
*
,
PyObject
*
);
...
...
@@ -69,6 +75,7 @@ struct npy_api {
private:
enum
functions
{
API_PyArray_Type
=
2
,
API_PyArrayDescr_Type
=
3
,
API_PyArray_DescrFromType
=
45
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
...
...
@@ -90,6 +97,7 @@ private:
npy_api
api
;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API
(
PyArray_Type
);
DECL_NPY_API
(
PyArrayDescr_Type
);
DECL_NPY_API
(
PyArray_DescrFromType
);
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
...
...
@@ -104,56 +112,55 @@ private:
};
}
class
array
:
public
buffer
{
class
dtype
:
public
object
{
public:
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
detail
::
npy_api
::
get
().
PyArray_Check_
)
PYBIND11_OBJECT_DEFAULT
(
dtype
,
object
,
detail
::
npy_api
::
get
().
PyArray
Descr
_Check_
)
;
enum
{
c_style
=
detail
::
npy_api
::
NPY_C_CONTIGUOUS_
,
f_style
=
detail
::
npy_api
::
NPY_F_CONTIGUOUS_
,
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
dtype
(
const
buffer_info
&
info
)
{
dtype
descr
(
_dtype_from_pep3118
()(
pybind11
::
str
(
info
.
format
)));
m_ptr
=
descr
.
strip_padding
().
release
().
ptr
();
}
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
Type
>::
dtype
().
release
().
ptr
();
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
object
tmp
=
object
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
1
,
&
shape
,
nullptr
,
(
void
*
)
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
dtype
(
std
::
string
format
)
{
m_ptr
=
from_args
(
pybind11
::
str
(
format
)).
release
().
ptr
();
}
array
(
const
buffer_info
&
info
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
static
dtype
from_args
(
object
args
)
{
// This is essentially the same as calling np.dtype() constructor in Python
PyObject
*
ptr
=
nullptr
;
if
(
!
detail
::
npy_api
::
get
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
ptr
)
||
!
ptr
)
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
return
object
(
ptr
,
false
);
}
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
auto
numpy_internal
=
module
::
import
(
"numpy.core._internal"
);
auto
dtype_from_fmt
=
(
object
)
numpy_internal
.
attr
(
"_dtype_from_pep3118"
);
auto
dtype
=
strip_padding_fields
(
dtype_from_fmt
(
pybind11
::
str
(
info
.
format
)));
template
<
typename
T
>
static
dtype
of
()
{
return
detail
::
npy_format_descriptor
<
T
>::
dtype
();
}
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
dtype
.
release
().
ptr
(),
(
int
)
info
.
ndim
,
(
Py_intptr_t
*
)
&
info
.
shape
[
0
],
(
Py_intptr_t
*
)
&
info
.
strides
[
0
],
info
.
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
info
.
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
size_t
itemsize
()
const
{
return
(
size_t
)
attr
(
"itemsize"
).
cast
<
int_
>
();
}
protected:
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
bool
has_fields
()
const
{
return
attr
(
"fields"
).
cast
<
object
>
().
ptr
()
!=
Py_None
;
}
std
::
string
kind
()
const
{
return
(
std
::
string
)
attr
(
"kind"
).
cast
<
pybind11
::
str
>
();
}
private:
static
object
&
_dtype_from_pep3118
()
{
static
object
obj
=
module
::
import
(
"numpy.core._internal"
).
attr
(
"_dtype_from_pep3118"
);
return
obj
;
}
static
object
strip_padding
_fields
(
object
dtype
)
{
dtype
strip_padding
(
)
{
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
auto
fields
=
dtype
.
attr
(
"fields"
).
cast
<
object
>
();
auto
fields
=
attr
(
"fields"
).
cast
<
object
>
();
if
(
fields
.
ptr
()
==
Py_None
)
return
dtype
;
return
*
this
;
struct
field_descr
{
pybind11
::
str
name
;
object
format
;
int_
offset
;
};
std
::
vector
<
field_descr
>
field_descriptors
;
...
...
@@ -162,11 +169,11 @@ protected:
for
(
auto
field
:
items
())
{
auto
spec
=
object
(
field
,
true
).
cast
<
tuple
>
();
auto
name
=
spec
[
0
].
cast
<
pybind11
::
str
>
();
auto
format
=
spec
[
1
].
cast
<
tuple
>
()[
0
].
cast
<
object
>
();
auto
format
=
spec
[
1
].
cast
<
tuple
>
()[
0
].
cast
<
dtype
>
();
auto
offset
=
spec
[
1
].
cast
<
tuple
>
()[
1
].
cast
<
int_
>
();
if
(
!
len
(
name
)
&&
(
std
::
string
)
dtype
.
attr
(
"kind"
).
cast
<
pybind11
::
str
>
()
==
"V"
)
if
(
!
len
(
name
)
&&
format
.
kind
()
==
"V"
)
continue
;
field_descriptors
.
push_back
({
name
,
strip_padding
_fields
(
format
),
offset
});
field_descriptors
.
push_back
({
name
,
format
.
strip_padding
(
),
offset
});
}
std
::
sort
(
field_descriptors
.
begin
(),
field_descriptors
.
end
(),
...
...
@@ -176,19 +183,57 @@ protected:
list
names
,
formats
,
offsets
;
for
(
auto
&
descr
:
field_descriptors
)
{
names
.
append
(
descr
.
name
);
formats
.
append
(
descr
.
format
);
offsets
.
append
(
descr
.
offset
);
names
.
append
(
descr
.
name
);
formats
.
append
(
descr
.
format
);
offsets
.
append
(
descr
.
offset
);
}
auto
args
=
dict
();
args
[
"names"
]
=
names
;
args
[
"formats"
]
=
formats
;
args
[
"offsets"
]
=
offsets
;
args
[
"itemsize"
]
=
dtype
.
attr
(
"itemsize"
).
cast
<
int_
>
();
args
[
"itemsize"
]
=
(
int_
)
itemsize
();
return
dtype
::
from_args
(
args
);
}
};
PyObject
*
descr
=
nullptr
;
if
(
!
detail
::
npy_api
::
get
().
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
descr
)
||
!
descr
)
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
return
object
(
descr
,
false
);
class
array
:
public
buffer
{
public:
PYBIND11_OBJECT_DEFAULT
(
array
,
buffer
,
detail
::
npy_api
::
get
().
PyArray_Check_
)
enum
{
c_style
=
detail
::
npy_api
::
NPY_C_CONTIGUOUS_
,
f_style
=
detail
::
npy_api
::
NPY_F_CONTIGUOUS_
,
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
descr
=
pybind11
::
dtype
::
of
<
Type
>
().
release
().
ptr
();
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
object
tmp
=
object
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
1
,
&
shape
,
nullptr
,
(
void
*
)
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
array
(
const
buffer_info
&
info
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
descr
=
pybind11
::
dtype
(
info
).
release
().
ptr
();
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
,
(
int
)
info
.
ndim
,
(
Py_intptr_t
*
)
&
info
.
shape
[
0
],
(
Py_intptr_t
*
)
&
info
.
strides
[
0
],
info
.
ptr
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
info
.
ptr
)
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
m_ptr
=
tmp
.
release
().
ptr
();
}
pybind11
::
dtype
dtype
()
{
return
attr
(
"dtype"
).
cast
<
pybind11
::
dtype
>
();
}
protected:
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
};
template
<
typename
T
,
int
ExtraFlags
=
array
::
forcecast
>
class
array_t
:
public
array
{
...
...
@@ -201,8 +246,7 @@ public:
if
(
ptr
==
nullptr
)
return
nullptr
;
auto
&
api
=
detail
::
npy_api
::
get
();
PyObject
*
descr
=
detail
::
npy_format_descriptor
<
T
>::
dtype
().
release
().
ptr
();
PyObject
*
result
=
api
.
PyArray_FromAny_
(
ptr
,
descr
,
0
,
0
,
PyObject
*
result
=
api
.
PyArray_FromAny_
(
ptr
,
pybind11
::
dtype
::
of
<
T
>
().
release
().
ptr
(),
0
,
0
,
detail
::
npy_api
::
NPY_ENSURE_ARRAY_
|
ExtraFlags
,
nullptr
);
if
(
!
result
)
PyErr_Clear
();
...
...
@@ -223,11 +267,6 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
static
const
char
*
format
()
{
PYBIND11_DESCR
s
=
detail
::
_
<
N
>
()
+
detail
::
_
(
"s"
);
return
s
.
text
();
}
};
template
<
typename
T
>
object
dtype_of
()
{
return
detail
::
npy_format_descriptor
<
T
>::
dtype
();
}
NAMESPACE_BEGIN
(
detail
)
template
<
typename
T
>
struct
is_std_array
:
std
::
false_type
{
};
template
<
typename
T
,
size_t
N
>
struct
is_std_array
<
std
::
array
<
T
,
N
>>
:
std
::
true_type
{
};
...
...
@@ -252,7 +291,7 @@ private:
npy_api
::
NPY_INT_
,
npy_api
::
NPY_UINT_
,
npy_api
::
NPY_LONGLONG_
,
npy_api
::
NPY_ULONGLONG_
};
public:
enum
{
value
=
values
[
detail
::
log2
(
sizeof
(
T
))
*
2
+
(
std
::
is_unsigned
<
T
>::
value
?
1
:
0
)]
};
static
object
dtype
()
{
static
pybind11
::
dtype
dtype
()
{
if
(
auto
ptr
=
npy_api
::
get
().
PyArray_DescrFromType_
(
value
))
return
object
(
ptr
,
true
);
pybind11_fail
(
"Unsupported buffer format!"
);
...
...
@@ -267,7 +306,7 @@ template <typename T> constexpr const int npy_format_descriptor<
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value = npy_api::NumPyName }; \
static
object
dtype() { \
static
pybind11::dtype
dtype() { \
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
...
...
@@ -282,14 +321,9 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
auto& api = npy_api::get(); \
PyObject *descr = nullptr; \
static pybind11::dtype dtype() { \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
pybind11_fail("NumPy: failed to create string dtype"); \
return object(descr, false); \
return pybind11::dtype::from_args(pybind11::str(fmt.text())); \
} \
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
template
<
size_t
N
>
struct
npy_format_descriptor
<
char
[
N
]
>
{
DECL_CHAR_FMT
};
...
...
@@ -301,14 +335,14 @@ struct field_descriptor {
size_t
offset
;
size_t
size
;
const
char
*
format
;
object
descr
;
dtype
descr
;
};
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
is_pod_struct
<
T
>::
value
>::
type
>
{
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
object
dtype
()
{
static
pybind11
::
dtype
dtype
()
{
if
(
!
dtype_
())
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
return
object
(
dtype_
(),
true
);
...
...
@@ -321,7 +355,6 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
auto
&
api
=
npy_api
::
get
();
auto
args
=
dict
();
list
names
{
},
offsets
{
},
formats
{
};
for
(
auto
field
:
fields
)
{
...
...
@@ -333,10 +366,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
args
[
"names"
]
=
names
;
args
[
"offsets"
]
=
offsets
;
args
[
"formats"
]
=
formats
;
args
[
"itemsize"
]
=
int_
(
sizeof
(
T
));
// This is essentially the same as calling np.dtype() constructor in Python and passing
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if
(
!
api
.
PyArray_DescrConverter_
(
args
.
release
().
ptr
(),
&
dtype_
())
||
!
dtype_
())
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
dtype_
()
=
pybind11
::
dtype
::
from_args
(
args
).
release
().
ptr
();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
...
...
@@ -366,9 +396,9 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
format_
()
=
oss
.
str
();
// Sanity check: verify that NumPy properly parses our buffer format string
auto
&
api
=
npy_api
::
get
();
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format
(),
1
,
{
0
},
{
sizeof
(
T
)
}));
auto
fixed_dtype
=
array
::
strip_padding_fields
(
object
(
dtype_
(),
true
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_
(),
fixed_dtype
.
ptr
()))
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_
(),
arr
.
dtype
().
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment