Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
43398a85
Commit
43398a85
authored
Jul 28, 2015
by
Wenzel Jakob
Browse files
complex number support
parent
d4258baf
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
88 additions
and
28 deletions
+88
-28
example/example10.cpp
example/example10.cpp
+6
-1
example/example10.py
example/example10.py
+3
-0
include/pybind/cast.h
include/pybind/cast.h
+17
-0
include/pybind/common.h
include/pybind/common.h
+6
-4
include/pybind/numpy.h
include/pybind/numpy.h
+39
-14
include/pybind/pybind.h
include/pybind/pybind.h
+17
-9
No files found.
example/example10.cpp
View file @
43398a85
...
@@ -15,14 +15,19 @@ double my_func(int x, float y, double z) {
...
@@ -15,14 +15,19 @@ double my_func(int x, float y, double z) {
return
x
*
y
*
z
;
return
x
*
y
*
z
;
}
}
std
::
complex
<
double
>
my_func3
(
std
::
complex
<
double
>
c
)
{
return
c
*
std
::
complex
<
double
>
(
2.
f
);
}
void
init_ex10
(
py
::
module
&
m
)
{
void
init_ex10
(
py
::
module
&
m
)
{
// Vectorize all arguments (though non-vector arguments are also allowed)
// Vectorize all arguments (though non-vector arguments are also allowed)
m
.
def
(
"vectorized_func"
,
py
::
vectorize
(
my_func
));
m
.
def
(
"vectorized_func"
,
py
::
vectorize
(
my_func
));
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
m
.
def
(
"vectorized_func2"
,
m
.
def
(
"vectorized_func2"
,
[](
py
::
array_dtype
<
int
>
x
,
py
::
array_dtype
<
float
>
y
,
float
z
)
{
[](
py
::
array_dtype
<
int
>
x
,
py
::
array_dtype
<
float
>
y
,
float
z
)
{
return
py
::
vectorize
([
z
](
int
x
,
float
y
)
{
return
my_func
(
x
,
y
,
z
);
})(
x
,
y
);
return
py
::
vectorize
([
z
](
int
x
,
float
y
)
{
return
my_func
(
x
,
y
,
z
);
})(
x
,
y
);
}
}
);
);
// Vectorize all arguments (complex numbers)
m
.
def
(
"vectorized_func3"
,
py
::
vectorize
(
my_func3
));
}
}
example/example10.py
View file @
43398a85
...
@@ -7,6 +7,9 @@ import numpy as np
...
@@ -7,6 +7,9 @@ import numpy as np
from
example
import
vectorized_func
from
example
import
vectorized_func
from
example
import
vectorized_func2
from
example
import
vectorized_func2
from
example
import
vectorized_func3
print
(
vectorized_func3
(
np
.
array
(
3
+
7j
)))
for
f
in
[
vectorized_func
,
vectorized_func2
]:
for
f
in
[
vectorized_func
,
vectorized_func2
]:
print
(
f
(
1
,
2
,
3
))
print
(
f
(
1
,
2
,
3
))
...
...
include/pybind/cast.h
View file @
43398a85
...
@@ -192,6 +192,23 @@ public:
...
@@ -192,6 +192,23 @@ public:
PYBIND_TYPE_CASTER
(
bool
,
"bool"
);
PYBIND_TYPE_CASTER
(
bool
,
"bool"
);
};
};
template
<
typename
T
>
class
type_caster
<
std
::
complex
<
T
>>
{
public:
bool
load
(
PyObject
*
src
,
bool
)
{
Py_complex
result
=
PyComplex_AsCComplex
(
src
);
if
(
result
.
real
==
-
1.0
&&
PyErr_Occurred
())
{
PyErr_Clear
();
return
false
;
}
value
=
std
::
complex
<
T
>
((
T
)
result
.
real
,
(
T
)
result
.
imag
);
return
true
;
}
static
PyObject
*
cast
(
const
std
::
complex
<
T
>
&
src
,
return_value_policy
/* policy */
,
PyObject
*
/* parent */
)
{
return
PyComplex_FromDoubles
((
double
)
src
.
real
(),
(
double
)
src
.
imag
());
}
PYBIND_TYPE_CASTER
(
std
::
complex
<
T
>
,
"complex"
);
};
template
<
>
class
type_caster
<
std
::
string
>
{
template
<
>
class
type_caster
<
std
::
string
>
{
public:
public:
bool
load
(
PyObject
*
src
,
bool
)
{
bool
load
(
PyObject
*
src
,
bool
)
{
...
...
include/pybind/common.h
View file @
43398a85
...
@@ -33,7 +33,7 @@
...
@@ -33,7 +33,7 @@
#include <unordered_map>
#include <unordered_map>
#include <iostream>
#include <iostream>
#include <memory>
#include <memory>
#include <
functional
>
#include <
complex
>
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
#if defined(_MSC_VER)
...
@@ -82,7 +82,8 @@ template <typename type> struct format_descriptor { };
...
@@ -82,7 +82,8 @@ template <typename type> struct format_descriptor { };
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
DECL_FMT
(
int8_t
,
"b"
);
DECL_FMT
(
uint8_t
,
"B"
);
DECL_FMT
(
int16_t
,
"h"
);
DECL_FMT
(
uint16_t
,
"H"
);
DECL_FMT
(
int8_t
,
"b"
);
DECL_FMT
(
uint8_t
,
"B"
);
DECL_FMT
(
int16_t
,
"h"
);
DECL_FMT
(
uint16_t
,
"H"
);
DECL_FMT
(
int32_t
,
"i"
);
DECL_FMT
(
uint32_t
,
"I"
);
DECL_FMT
(
int64_t
,
"q"
);
DECL_FMT
(
uint64_t
,
"Q"
);
DECL_FMT
(
int32_t
,
"i"
);
DECL_FMT
(
uint32_t
,
"I"
);
DECL_FMT
(
int64_t
,
"q"
);
DECL_FMT
(
uint64_t
,
"Q"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
float
,
"f"
);
DECL_FMT
(
double
,
"d"
);
DECL_FMT
(
bool
,
"?"
);
DECL_FMT
(
std
::
complex
<
float
>
,
"Zf"
);
DECL_FMT
(
std
::
complex
<
double
>
,
"Zd"
);
#undef DECL_FMT
#undef DECL_FMT
/// Information record describing a Python buffer object
/// Information record describing a Python buffer object
...
@@ -126,11 +127,12 @@ struct type_info {
...
@@ -126,11 +127,12 @@ struct type_info {
PyTypeObject
*
type
;
PyTypeObject
*
type
;
size_t
type_size
;
size_t
type_size
;
void
(
*
init_holder
)(
PyObject
*
);
void
(
*
init_holder
)(
PyObject
*
);
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
get_buffer
;
std
::
vector
<
PyObject
*
(
*
)(
PyObject
*
,
PyTypeObject
*
)
>
implicit_conversions
;
std
::
vector
<
PyObject
*
(
*
)(
PyObject
*
,
PyTypeObject
*
)
>
implicit_conversions
;
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
)
=
nullptr
;
void
*
get_buffer_data
=
nullptr
;
};
};
/// Internal data struture used to track registered instances and types
/// Internal data struture used to track registered instances and types
struct
internals
{
struct
internals
{
std
::
unordered_map
<
std
::
string
,
type_info
>
registered_types
;
std
::
unordered_map
<
std
::
string
,
type_info
>
registered_types
;
std
::
unordered_map
<
void
*
,
PyObject
*>
registered_instances
;
std
::
unordered_map
<
void
*
,
PyObject
*>
registered_instances
;
...
...
include/pybind/numpy.h
View file @
43398a85
...
@@ -17,8 +17,10 @@
...
@@ -17,8 +17,10 @@
NAMESPACE_BEGIN
(
pybind
)
NAMESPACE_BEGIN
(
pybind
)
template
<
typename
type
>
struct
npy_format_descriptor
{
};
class
array
:
public
buffer
{
class
array
:
public
buffer
{
p
rotected
:
p
ublic
:
struct
API
{
struct
API
{
enum
Entries
{
enum
Entries
{
API_PyArray_Type
=
2
,
API_PyArray_Type
=
2
,
...
@@ -26,10 +28,18 @@ protected:
...
@@ -26,10 +28,18 @@ protected:
API_PyArray_FromAny
=
69
,
API_PyArray_FromAny
=
69
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewCopy
=
85
,
API_PyArray_NewFromDescr
=
94
,
API_PyArray_NewFromDescr
=
94
,
API_NPY_C_CONTIGUOUS
=
0x0001
,
NPY_C_CONTIGUOUS
=
0x0001
,
API_NPY_F_CONTIGUOUS
=
0x0002
,
NPY_F_CONTIGUOUS
=
0x0002
,
API_NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
NPY_NPY_ARRAY_FORCECAST
=
0x0010
,
API_NPY_ENSURE_ARRAY
=
0x0040
NPY_ENSURE_ARRAY
=
0x0040
,
NPY_BOOL
=
0
,
NPY_BYTE
,
NPY_UBYTE
,
NPY_SHORT
,
NPY_USHORT
,
NPY_INT
,
NPY_UINT
,
NPY_LONG
,
NPY_ULONG
,
NPY_LONGLONG
,
NPY_ULONGLONG
,
NPY_FLOAT
,
NPY_DOUBLE
,
NPY_LONGDOUBLE
,
NPY_CFLOAT
,
NPY_CDOUBLE
,
NPY_CLONGDOUBLE
};
};
static
API
lookup
()
{
static
API
lookup
()
{
...
@@ -59,13 +69,12 @@ protected:
...
@@ -59,13 +69,12 @@ protected:
PyTypeObject
*
PyArray_Type
;
PyTypeObject
*
PyArray_Type
;
PyObject
*
(
*
PyArray_FromAny
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
PyObject
*
(
*
PyArray_FromAny
)
(
PyObject
*
,
PyObject
*
,
int
,
int
,
int
,
PyObject
*
);
};
};
public:
PYBIND_OBJECT_DEFAULT
(
array
,
buffer
,
lookup_api
().
PyArray_Check
)
PYBIND_OBJECT_DEFAULT
(
array
,
buffer
,
lookup_api
().
PyArray_Check
)
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
template
<
typename
Type
>
array
(
size_t
size
,
const
Type
*
ptr
)
{
API
&
api
=
lookup_api
();
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_format_descriptor
<
Type
>::
value
);
(
int
)
format_descriptor
<
Type
>::
value
()[
0
]);
if
(
descr
==
nullptr
)
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format!"
);
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format!"
);
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
Py_intptr_t
shape
=
(
Py_intptr_t
)
size
;
...
@@ -83,7 +92,12 @@ public:
...
@@ -83,7 +92,12 @@ public:
API
&
api
=
lookup_api
();
API
&
api
=
lookup_api
();
if
(
info
.
format
.
size
()
!=
1
)
if
(
info
.
format
.
size
()
!=
1
)
throw
std
::
runtime_error
(
"Unsupported buffer format!"
);
throw
std
::
runtime_error
(
"Unsupported buffer format!"
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
info
.
format
[
0
]);
int
fmt
=
(
int
)
info
.
format
[
0
];
if
(
info
.
format
==
"Zd"
)
fmt
=
API
::
NPY_CDOUBLE
;
else
if
(
info
.
format
==
"Zf"
)
fmt
=
API
::
NPY_CFLOAT
;
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
fmt
);
if
(
descr
==
nullptr
)
if
(
descr
==
nullptr
)
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format '"
+
info
.
format
+
"'!"
);
throw
std
::
runtime_error
(
"NumPy: unsupported buffer format '"
+
info
.
format
+
"'!"
);
PyObject
*
tmp
=
api
.
PyArray_NewFromDescr
(
PyObject
*
tmp
=
api
.
PyArray_NewFromDescr
(
...
@@ -109,12 +123,12 @@ public:
...
@@ -109,12 +123,12 @@ public:
PYBIND_OBJECT_CVT
(
array_dtype
,
array
,
is_non_null
,
m_ptr
=
ensure
(
m_ptr
));
PYBIND_OBJECT_CVT
(
array_dtype
,
array
,
is_non_null
,
m_ptr
=
ensure
(
m_ptr
));
array_dtype
()
:
array
()
{
}
array_dtype
()
:
array
()
{
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
static
PyObject
*
ensure
(
PyObject
*
ptr
)
{
PyObject
*
ensure
(
PyObject
*
ptr
)
{
API
&
api
=
lookup_api
();
API
&
api
=
lookup_api
();
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
format_descriptor
<
T
>::
value
()[
0
]
);
PyObject
*
descr
=
api
.
PyArray_DescrFromType
(
npy_
format_descriptor
<
T
>::
value
);
return
api
.
PyArray_FromAny
(
ptr
,
descr
,
0
,
0
,
return
api
.
PyArray_FromAny
(
ptr
,
descr
,
0
,
0
,
API
::
API_
NPY_C_CONTIGUOUS
|
API
::
API_
NPY_ENSURE_ARRAY
|
API
::
NPY_C_CONTIGUOUS
|
API
::
NPY_ENSURE_ARRAY
|
API
::
API_
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
API
::
NPY_NPY_ARRAY_FORCECAST
,
nullptr
);
}
}
};
};
...
@@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_
...
@@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint32_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
int64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
uint64_t
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
float
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
double
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
float
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
double
>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
float
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
std
::
complex
<
double
>>
)
PYBIND_TYPE_CASTER_PYTYPE
(
array_dtype
<
bool
>
)
NAMESPACE_END
(
detail
)
NAMESPACE_END
(
detail
)
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
DECL_FMT
(
int8_t
,
NPY_BYTE
);
DECL_FMT
(
uint8_t
,
NPY_UBYTE
);
DECL_FMT
(
int16_t
,
NPY_SHORT
);
DECL_FMT
(
uint16_t
,
NPY_USHORT
);
DECL_FMT
(
int32_t
,
NPY_INT
);
DECL_FMT
(
uint32_t
,
NPY_UINT
);
DECL_FMT
(
int64_t
,
NPY_LONGLONG
);
DECL_FMT
(
uint64_t
,
NPY_ULONGLONG
);
DECL_FMT
(
float
,
NPY_FLOAT
);
DECL_FMT
(
double
,
NPY_DOUBLE
);
DECL_FMT
(
bool
,
NPY_BOOL
);
DECL_FMT
(
std
::
complex
<
float
>
,
NPY_CFLOAT
);
DECL_FMT
(
std
::
complex
<
double
>
,
NPY_CDOUBLE
);
#undef DECL_FMT
template
<
typename
func_type
,
typename
return_type
,
typename
...
args_type
,
size_t
...
Index
>
template
<
typename
func_type
,
typename
return_type
,
typename
...
args_type
,
size_t
...
Index
>
std
::
function
<
object
(
array_dtype
<
args_type
>
...)
>
std
::
function
<
object
(
array_dtype
<
args_type
>
...)
>
vectorize
(
func_type
&&
f
,
return_type
(
*
)
(
args_type
...),
vectorize
(
func_type
&&
f
,
return_type
(
*
)
(
args_type
...),
...
@@ -171,7 +196,7 @@ template <typename func_type, typename return_type, typename... args_type, size_
...
@@ -171,7 +196,7 @@ template <typename func_type, typename return_type, typename... args_type, size_
return
cast
(
result
[
0
]);
return
cast
(
result
[
0
]);
/* Return the result */
/* Return the result */
return
array
(
buffer_info
(
result
.
data
(),
sizeof
(
return_type
),
return
array
(
buffer_info
(
result
.
data
(),
sizeof
(
return_type
),
format_descriptor
<
return_type
>::
value
(),
format_descriptor
<
return_type
>::
value
(),
ndim
,
shape
,
strides
));
ndim
,
shape
,
strides
));
};
};
...
...
include/pybind/pybind.h
View file @
43398a85
...
@@ -393,22 +393,27 @@ protected:
...
@@ -393,22 +393,27 @@ protected:
Py_TYPE
(
self
)
->
tp_free
((
PyObject
*
)
self
);
Py_TYPE
(
self
)
->
tp_free
((
PyObject
*
)
self
);
}
}
void
install_buffer_funcs
(
const
std
::
function
<
buffer_info
*
(
PyObject
*
)
>
&
func
)
{
void
install_buffer_funcs
(
buffer_info
*
(
*
get_buffer
)(
PyObject
*
,
void
*
),
void
*
get_buffer_data
)
{
PyHeapTypeObject
*
type
=
(
PyHeapTypeObject
*
)
m_ptr
;
PyHeapTypeObject
*
type
=
(
PyHeapTypeObject
*
)
m_ptr
;
type
->
ht_type
.
tp_as_buffer
=
&
type
->
as_buffer
;
type
->
ht_type
.
tp_as_buffer
=
&
type
->
as_buffer
;
type
->
as_buffer
.
bf_getbuffer
=
getbuffer
;
type
->
as_buffer
.
bf_getbuffer
=
getbuffer
;
type
->
as_buffer
.
bf_releasebuffer
=
releasebuffer
;
type
->
as_buffer
.
bf_releasebuffer
=
releasebuffer
;
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)))
->
get_buffer
=
func
;
auto
info
=
((
detail
::
type_info
*
)
capsule
(
attr
(
"__pybind__"
)));
info
->
get_buffer
=
get_buffer
;
info
->
get_buffer_data
=
get_buffer_data
;
}
}
static
int
getbuffer
(
PyObject
*
obj
,
Py_buffer
*
view
,
int
flags
)
{
static
int
getbuffer
(
PyObject
*
obj
,
Py_buffer
*
view
,
int
flags
)
{
auto
const
&
info_func
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)))
->
get_buffer
;
auto
const
&
typeinfo
=
((
detail
::
type_info
*
)
capsule
(
handle
(
obj
).
attr
(
"__pybind__"
)));
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
info_func
)
{
if
(
view
==
nullptr
||
obj
==
nullptr
||
!
typeinfo
||
!
typeinfo
->
get_buffer
)
{
PyErr_SetString
(
PyExc_BufferError
,
"Internal error"
);
PyErr_SetString
(
PyExc_BufferError
,
"Internal error"
);
return
-
1
;
return
-
1
;
}
}
memset
(
view
,
0
,
sizeof
(
Py_buffer
));
memset
(
view
,
0
,
sizeof
(
Py_buffer
));
buffer_info
*
info
=
info_func
(
obj
);
buffer_info
*
info
=
typeinfo
->
get_buffer
(
obj
,
typeinfo
->
get_buffer_data
);
view
->
obj
=
obj
;
view
->
obj
=
obj
;
view
->
ndim
=
1
;
view
->
ndim
=
1
;
view
->
internal
=
info
;
view
->
internal
=
info
;
...
@@ -483,13 +488,16 @@ public:
...
@@ -483,13 +488,16 @@ public:
return
*
this
;
return
*
this
;
}
}
class_
&
def_buffer
(
const
std
::
function
<
buffer_info
(
type
&
)
>
&
func
)
{
template
<
typename
Func
>
install_buffer_funcs
([
func
](
PyObject
*
obj
)
->
buffer_info
*
{
class_
&
def_buffer
(
Func
&&
func
)
{
struct
capture
{
Func
func
;
};
capture
*
ptr
=
new
capture
{
std
::
forward
<
Func
>
(
func
)
};
install_buffer_funcs
([](
PyObject
*
obj
,
void
*
ptr
)
->
buffer_info
*
{
detail
::
type_caster
<
type
>
caster
;
detail
::
type_caster
<
type
>
caster
;
if
(
!
caster
.
load
(
obj
,
false
))
if
(
!
caster
.
load
(
obj
,
false
))
return
nullptr
;
return
nullptr
;
return
new
buffer_info
(
func
(
caster
));
return
new
buffer_info
(
((
capture
*
)
ptr
)
->
func
(
caster
));
});
}
,
ptr
);
return
*
this
;
return
*
this
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment