Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
43398a85
Commit
43398a85
authored
Jul 28, 2015
by
Wenzel Jakob
Browse files
complex number support
parent
d4258baf
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
88 additions
and
28 deletions
+88
-28
example/example10.cpp
example/example10.cpp
+6
-1
example/example10.py
example/example10.py
+3
-0
include/pybind/cast.h
include/pybind/cast.h
+17
-0
include/pybind/common.h
include/pybind/common.h
+6
-4
include/pybind/numpy.h
include/pybind/numpy.h
+39
-14
include/pybind/pybind.h
include/pybind/pybind.h
+17
-9
No files found.
example/example10.cpp
View file @
43398a85
...
...
@@ -15,14 +15,19 @@ double my_func(int x, float y, double z) {
return
x
*
y
*
z
;
}
std
::
complex
<
double
>
my_func3
(
std
::
complex
<
double
>
c
)
{
return
c
*
std
::
complex
<
double
>
(
2.
f
);
}
void
init_ex10
(
py
::
module
&
m
)
{
// Vectorize all arguments (though non-vector arguments are also allowed)
m
.
def
(
"vectorized_func"
,
py
::
vectorize
(
my_func
));
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
m
.
def
(
"vectorized_func2"
,
[](
py
::
array_dtype
<
int
>
x
,
py
::
array_dtype
<
float
>
y
,
float
z
)
{
return
py
::
vectorize
([
z
](
int
x
,
float
y
)
{
return
my_func
(
x
,
y
,
z
);
})(
x
,
y
);
}
);
// Vectorize all arguments (complex numbers)
m
.
def
(
"vectorized_func3"
,
py
::
vectorize
(
my_func3
));
}
example/example10.py
View file @
43398a85
...
...
@@ -7,6 +7,9 @@ import numpy as np
from
example
import
vectorized_func
from
example
import
vectorized_func2
from
example
import
vectorized_func3
print
(
vectorized_func3
(
np
.
array
(
3
+
7j
)))
for
f
in
[
vectorized_func
,
vectorized_func2
]:
print
(
f
(
1
,
2
,
3
))
...
...
include/pybind/cast.h
View file @
43398a85
...
...
@@ -192,6 +192,23 @@ public:
PYBIND_TYPE_CASTER
(
bool
,
"bool"
);
};
template
<
typename
T
>
class
type_caster
<
std
::
complex
<
T
>>
{
public:
bool
load
(
PyObject
*
src
,
bool
)
{
Py_complex
result
=
PyComplex_AsCComplex
(
src
);
if
(
result
.
real
==
-
1.0
&&
PyErr_Occurred
())
{
PyErr_Clear
();
return
false
;
}
value
=
std
::
complex
<
T
>
((
T
)
result
.
real
,
(
T
)
result
.
imag
);
return
true
;
}
static
PyObject
*
cast
(
const
std
::
complex
<
T
>
&
src
,
return_value_policy
/* policy */
,
PyObject
*
/* parent */
)
{
return
PyComplex_FromDoubles
((
double
)
src
.
real
(),
(
double
)
src
.
imag
());
}
PYBIND_TYPE_CASTER
(
std
::
complex
<
T
>
,
"complex"
);
};
template
<
>
class
type_caster
<
std
::
string
>
{
public:
bool
load
(
PyObject
*
src
,
bool
)
{
...
...
include/pybind/common.h
View file @
43398a85
...
...
@@ -33,7 +33,7 @@
#include <unordered_map>
#include <iostream>
#include <memory>
#include <
functional
>
#include <
complex
>
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
...
...
@@ -82,7 +82,8 @@ template <typename type> struct format_descriptor { };
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
DECL_FMT
(
int8_t
,
"b"
);
DECL_FMT
(
uint8_t
,
"B"
);
DECL_FMT
(
int16_t
,
"h"
);
DECL_FMT
(
uint16_t
,
"H"
);
DECL_FMT
(
int32_t
,
"i"
);
DECL_FMT
(
uint32_t
,
"I"
);
DECL_FMT
(
int64_t
,
"q"
);
DECL_FMT
(
uint64_t
,
"Q"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
bool
,
"?"
);
DECL_FMT
(
std
::
complex
<
float
>
,
"Zf"
);
DECL_FMT
(
std
::
complex
<
double
>
,
"Zd"
);
#undef DECL_FMT
/// Information record describing a Python buffer object
...
...
@@ -126,8 +127,9 @@ struct type_info {
PyTypeObject
*
type
;
size_t
type_size
;
void
(
*
init_holder
)(
PyObject
*
);
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
get_buffer
;
std
::
vector
<
PyObject
*
(
*
)(
PyObject
*
,
PyTypeObject
*
)
>
implicit_conversions
;
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
)
=
nullptr
;
void
*
get_buffer_data
=
nullptr
;
};
/// Internal data struture used to track registered instances and types
...
...
include/pybind/numpy.h
View file @
43398a85
...
...
@@ -17,8 +17,10 @@
NAMESPACE_BEGIN
(
pybind
)
template
<
typename
type
>
struct
npy_format_descriptor
{
};
class
array
:
public
buffer
{
p
rotected
:
p
ublic
:
struct
API
{
enum
Entries
{
API_PyArray_Type
=
2
,
...
...
@@ -26,10 +28,18 @@ protected:
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_NPY_C_CONTIGUOUS
=
0x0001
,
API_NPY_F_CONTIGUOUS
=
0x0002
,
API_NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
API_NPY_ENSURE_ARRAY
=
0x0040
NPY_C_CONTIGUOUS
=
0x0001
,
NPY_F_CONTIGUOUS
=
0x0002
,
NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
NPY_ENSURE_ARRAY
=
0x0040
,
NPY_BOOL
=
0
,
NPY_BYTE
,
NPY_UBYTE
,
NPY_SHORT
,
NPY_USHORT
,
NPY_INT
,
NPY_UINT
,
NPY_LONG
,
NPY_ULONG
,
NPY_LONGLONG
,
NPY_ULONGLONG
,
NPY_FLOAT
,
NPY_DOUBLE
,
NPY_LONGDOUBLE
,
NPY_CFLOAT
,
NPY_CDOUBLE
,
NPY_CLONGDOUBLE
};
static
API
lookup
()
{
...
...
@@ -59,13 +69,12 @@ protected:
PyTypeObject
*
PyArray_Type
;
PyObject
*
(
*
PyArray_FromAny
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
};
public:
PYBIND_OBJECT_DEFAULT
(
array
,
buffer
,
lookup_api
().
PyArray_Check
)
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
(
int
)
format_descriptor
<
Type
>::
value
()[
0
]);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_format_descriptor
<
Type
>::
value
);
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format!"
);
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
...
...
@@ -83,7 +92,12 @@ public:
API
&
api
=
lookup_api
();
if
(
info
.
format
.
size
()
!=
1
)
throw
std
::
runtime_error
(
"Unsupported buffer format!"
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
info
.
format
[
0
]);
int
fmt
=
(
int
)
info
.
format
[
0
];
if
(
info
.
format
==
"Zd"
)
fmt
=
API
::
NPY_CDOUBLE
;
else
if
(
info
.
format
==
"Zf"
)
fmt
=
API
::
NPY_CFLOAT
;
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
fmt
);
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format '"
+
info
.
format
+
"'!"
);
PyObject
*
tmp
=
api
.
PyArray_NewFromDescr
(
...
...
@@ -109,12 +123,12 @@ public:
PYBIND_OBJECT_CVT
(
array_dtype
,
array
,
is_non_null
,
m_ptr
=
ensure
(
m_ptr
));
array_dtype
()
:
array
()
{
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
static
PyObject
*
ensure
(
PyObject
*
ptr
)
{
PyObject
*
ensure
(
PyObject
*
ptr
)
{
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
format_descriptor
<
T
>::
value
()[
0
]
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_
format_descriptor
<
T
>::
value
);
return
api
.
PyArray_FromAny
(
ptr
,
descr
,
0
,
0
,
API
::
API_
NPY_C_CONTIGUOUS
|
API
::
API_
NPY_ENSURE_ARRAY
|
API
::
API_
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
API
::
NPY_C_CONTIGUOUS
|
API
::
NPY_ENSURE_ARRAY
|
API
::
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
}
};
...
...
@@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
float
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
double
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
float
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
double
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
bool
>
)
NAMESPACE_END
(
detail
)
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
DECL_FMT
(
int8_t
,
NPY_BYTE
);
DECL_FMT
(
uint8_t
,
NPY_UBYTE
);
DECL_FMT
(
int16_t
,
NPY_SHORT
);
DECL_FMT
(
uint16_t
,
NPY_USHORT
);
DECL_FMT
(
int32_t
,
NPY_INT
);
DECL_FMT
(
uint32_t
,
NPY_UINT
);
DECL_FMT
(
int64_t
,
NPY_LONGLONG
);
DECL_FMT
(
uint64_t
,
NPY_ULONGLONG
);
DECL_FMT
(
float
,
NPY_FLOAT
);
DECL_FMT
(
double
,
NPY_DOUBLE
);
DECL_FMT
(
bool
,
NPY_BOOL
);
DECL_FMT
(
std
::
complex
<
float
>
,
NPY_CFLOAT
);
DECL_FMT
(
std
::
complex
<
double
>
,
NPY_CDOUBLE
);
#undef DECL_FMT
template
<
typename
func_type
,
typename
return_type
,
typename
...
args_type
,
size_t
...
Index
>
std
::
function
<
object
(
array_dtype
<
args_type
>
...)
>
vectorize
(
func_type
&&
f
,
return_type
(
*
)
(
args_type
...),
...
...
include/pybind/pybind.h
View file @
43398a85
...
...
@@ -393,22 +393,27 @@ protected:
Py_TYPE
(
self
)
->
tp_free
((
PyObject
*
)
self
);
}
void
install_buffer_funcs
(
const
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
&
func
)
{
void
install_buffer_funcs
(
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
),
void
*
get_buffer_data
)
{
PyHeapTypeObject
*
type
=
(
PyHeapTypeObject
*
)
m_ptr
;
type
->
ht_type
.
tp_as_buffer
=
&
type
->
as_buffer
;
type
->
as_buffer
.
bf_getbuffer
=
getbuffer
;
type
->
as_buffer
.
bf_releasebuffer
=
releasebuffer
;
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)))
->
get_buffer
=
func
;
auto
info
=
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)));
info
->
get_buffer
=
get_buffer
;
info
->
get_buffer_data
=
get_buffer_data
;
}
static
int
getbuffer
(
PyObject
*
obj
,
Py_buffer
*
view
,
int
flags
)
{
auto
const
&
info_func
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)))
->
get_buffer
;
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
info_func
)
{
auto
const
&
typeinfo
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)));
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
typeinfo
||
!
typeinfo
->
get_buffer
)
{
PyErr_SetString
(
PyExc_BufferError
,
"Internal error"
);
return
-
1
;
}
memset
(
view
,
0
,
sizeof
(
Py_buffer
));
buffer_info
*
info
=
info_func
(
obj
);
buffer_info
*
info
=
typeinfo
->
get_buffer
(
obj
,
typeinfo
->
get_buffer_data
);
view
->
obj
=
obj
;
view
->
ndim
=
1
;
view
->
internal
=
info
;
...
...
@@ -483,13 +488,16 @@ public:
return
*
this
;
}
class_
&
def_buffer
(
const
std
::
function
<
buffer_info
(
type
&
)
>
&
func
)
{
install_buffer_funcs
([
func
](
PyObject
*
obj
)
->
buffer_info
*
{
template
<
typename
Func
>
class_
&
def_buffer
(
Func
&&
func
)
{
struct
capture
{
Func
func
;
};
capture
*
ptr
=
new
capture
{
std
::
forward
<
Func
>
(
func
)
};
install_buffer_funcs
([](
PyObject
*
obj
,
void
*
ptr
)
->
buffer_info
*
{
detail
::
type_caster
<
type
>
caster
;
if
(
!
caster
.
load
(
obj
,
false
))
return
nullptr
;
return
new
buffer_info
(
func
(
caster
));
});
return
new
buffer_info
(
((
capture
*
)
ptr
)
->
func
(
caster
));
}
,
ptr
);
return
*
this
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment