Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
2488b320
Commit
2488b320
authored
Jun 19, 2016
by
Ivan Smirnov
Browse files
Add PYBIND11_DTYPE macro for registering dtypes
parent
fab02efb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
107 additions
and
0 deletions
+107
-0
include/pybind11/numpy.h
include/pybind11/numpy.h
+107
-0
No files found.
include/pybind11/numpy.h
View file @
2488b320
...
...
@@ -14,6 +14,8 @@
#include <numeric>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <initializer_list>
#if defined(_MSC_VER)
#pragma warning(push)
...
...
@@ -32,6 +34,7 @@ public:
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_PyArray_DescrConverter
=
174
,
API_PyArray_GetArrayParamsFromObject
=
278
,
NPY_C_CONTIGUOUS_
=
0x0001
,
...
...
@@ -63,6 +66,7 @@ public:
DECL_NPY_API
(
PyArray_FromAny
);
DECL_NPY_API
(
PyArray_NewCopy
);
DECL_NPY_API
(
PyArray_NewFromDescr
);
DECL_NPY_API
(
PyArray_DescrConverter
);
DECL_NPY_API
(
PyArray_GetArrayParamsFromObject
);
#undef DECL_NPY_API
return
api
;
...
...
@@ -77,6 +81,7 @@ public:
PyObject
*
(
*
PyArray_NewCopy_
)(
PyObject
*
,
int
);
PyTypeObject
*
PyArray_Type_
;
PyObject
*
(
*
PyArray_FromAny_
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
int
(
*
PyArray_DescrConverter_
)
(
PyObject
*
,
PyObject
**
);
int
(
*
PyArray_GetArrayParamsFromObject_
)(
PyObject
*
,
PyObject
*
,
char
,
PyObject
**
,
int
*
,
Py_ssize_t
*
,
PyObject
**
,
PyObject
*
);
};
...
...
@@ -149,6 +154,19 @@ public:
}
};
template
<
typename
T
>
struct
format_descriptor
<
T
,
typename
std
::
enable_if
<
std
::
is_pod
<
T
>::
value
&&
!
std
::
is_integral
<
T
>::
value
&&
!
std
::
is_same
<
T
,
float
>::
value
&&
!
std
::
is_same
<
T
,
bool
>::
value
&&
!
std
::
is_same
<
T
,
std
::
complex
<
float
>>::
value
&&
!
std
::
is_same
<
T
,
std
::
complex
<
double
>>::
value
>::
type
>
{
static
const
char
*
value
()
{
return
detail
::
npy_format_descriptor
<
T
>::
format_str
();
}
};
NAMESPACE_BEGIN
(
detail
)
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
...
...
@@ -184,6 +202,95 @@ DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT
(
std
::
complex
<
double
>
,
NPY_CDOUBLE_
,
"complex128"
);
#undef DECL_FMT
struct
field_descriptor
{
const
char
*
name
;
int
offset
;
PyObject
*
descr
;
};
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
typename
std
::
enable_if
<
std
::
is_pod
<
T
>::
value
&&
// offsetof only works correctly for POD types
!
std
::
is_integral
<
T
>::
value
&&
!
std
::
is_same
<
T
,
float
>::
value
&&
!
std
::
is_same
<
T
,
bool
>::
value
&&
!
std
::
is_same
<
T
,
std
::
complex
<
float
>>::
value
&&
!
std
::
is_same
<
T
,
std
::
complex
<
double
>>::
value
>::
type
>
{
static
PYBIND11_DESCR
name
()
{
return
_
(
"user-defined"
);
}
static
PyObject
*
descr
()
{
if
(
!
descr_
())
pybind11_fail
(
"NumPy: unsupported buffer format!"
);
return
descr_
();
}
static
const
char
*
format_str
()
{
return
format_str_
();
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
array
::
API
&
api
=
array
::
lookup_api
();
auto
args
=
py
::
dict
();
py
::
list
names
{
},
offsets
{
},
formats
{
};
std
::
vector
<
py
::
object
>
dtypes
;
for
(
auto
field
:
fields
)
{
names
.
append
(
py
::
str
(
field
.
name
));
offsets
.
append
(
py
::
int_
(
field
.
offset
));
if
(
!
field
.
descr
)
pybind11_fail
(
"NumPy: unsupported field dtype"
);
dtypes
.
emplace_back
(
field
.
descr
,
false
);
formats
.
append
(
dtypes
.
back
());
}
args
[
"names"
]
=
names
;
args
[
"offsets"
]
=
offsets
;
args
[
"formats"
]
=
formats
;
if
(
!
api
.
PyArray_DescrConverter_
(
args
.
ptr
(),
&
descr_
())
||
!
descr_
())
pybind11_fail
(
"NumPy: failed to create structured dtype"
);
auto
np
=
module
::
import
(
"numpy"
);
auto
empty
=
(
object
)
np
.
attr
(
"empty"
);
if
(
auto
arr
=
(
object
)
empty
(
py
::
int_
(
0
),
object
(
descr
(),
true
)))
if
(
auto
view
=
PyMemoryView_FromObject
(
arr
.
ptr
()))
if
(
auto
info
=
PyMemoryView_GET_BUFFER
(
view
))
{
std
::
strncpy
(
format_str_
(),
info
->
format
,
4096
);
return
;
}
pybind11_fail
(
"NumPy: failed to extract buffer format"
);
}
private:
static
inline
PyObject
*&
descr_
()
{
static
PyObject
*
ptr
=
nullptr
;
return
ptr
;
}
static
inline
char
*
format_str_
()
{
static
char
s
[
4096
];
return
s
;
}
};
#define FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \
#Field, offsetof(Type, Field), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::descr() }
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define EVAL0(...) __VA_ARGS__
#define EVAL1(...) EVAL0 (EVAL0 (EVAL0 (__VA_ARGS__)))
#define EVAL2(...) EVAL1 (EVAL1 (EVAL1 (__VA_ARGS__)))
#define EVAL3(...) EVAL2 (EVAL2 (EVAL2 (__VA_ARGS__)))
#define EVAL4(...) EVAL3 (EVAL3 (EVAL3 (__VA_ARGS__)))
#define EVAL(...) EVAL4 (EVAL4 (EVAL4 (__VA_ARGS__)))
#define MAP_END(...)
#define MAP_OUT
#define MAP_COMMA ,
#define MAP_GET_END() 0, MAP_END
#define MAP_NEXT0(test, next, ...) next MAP_OUT
#define MAP_NEXT1(test, next) MAP_NEXT0 (test, next, 0)
#define MAP_NEXT(test, next) MAP_NEXT1 (MAP_GET_END test, next)
#define MAP_LIST_NEXT1(test, next) MAP_NEXT0 (test, MAP_COMMA next, 0)
#define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1 (MAP_GET_END test, next)
#define MAP_LIST0(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define MAP_LIST1(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST0) (f, t, peek, __VA_ARGS__)
#define MAP_LIST(f, t, ...) EVAL (MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
#define PYBIND11_DTYPE(Type, ...) \
::pybind11::detail::npy_format_descriptor<Type>::register_dtype({MAP_LIST(FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
template
<
class
T
>
using
array_iterator
=
typename
std
::
add_pointer
<
T
>::
type
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment