Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
00488a3e
Commit
00488a3e
authored
Oct 13, 2016
by
Wenzel Jakob
Committed by
GitHub
Oct 13, 2016
Browse files
Merge pull request #440 from wjakob/master
Permit creation of NumPy arrays with a "base" object that owns the data
parents
43f6aa68
fac7c094
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
186 additions
and
29 deletions
+186
-29
include/pybind11/common.h
include/pybind11/common.h
+1
-0
include/pybind11/numpy.h
include/pybind11/numpy.h
+63
-24
include/pybind11/pybind11.h
include/pybind11/pybind11.h
+17
-5
tests/test_numpy_array.cpp
tests/test_numpy_array.cpp
+25
-0
tests/test_numpy_array.py
tests/test_numpy_array.py
+80
-0
No files found.
include/pybind11/common.h
View file @
00488a3e
...
@@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
...
@@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
PYBIND11_RUNTIME_EXCEPTION
(
index_error
,
PyExc_IndexError
)
PYBIND11_RUNTIME_EXCEPTION
(
index_error
,
PyExc_IndexError
)
PYBIND11_RUNTIME_EXCEPTION
(
key_error
,
PyExc_KeyError
)
PYBIND11_RUNTIME_EXCEPTION
(
key_error
,
PyExc_KeyError
)
PYBIND11_RUNTIME_EXCEPTION
(
value_error
,
PyExc_ValueError
)
PYBIND11_RUNTIME_EXCEPTION
(
value_error
,
PyExc_ValueError
)
PYBIND11_RUNTIME_EXCEPTION
(
import_error
,
PyExc_ImportError
)
PYBIND11_RUNTIME_EXCEPTION
(
type_error
,
PyExc_TypeError
)
PYBIND11_RUNTIME_EXCEPTION
(
type_error
,
PyExc_TypeError
)
PYBIND11_RUNTIME_EXCEPTION
(
cast_error
,
PyExc_RuntimeError
)
/// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION
(
cast_error
,
PyExc_RuntimeError
)
/// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION
(
reference_cast_error
,
PyExc_RuntimeError
)
/// Used internally
PYBIND11_RUNTIME_EXCEPTION
(
reference_cast_error
,
PyExc_RuntimeError
)
/// Used internally
...
...
include/pybind11/numpy.h
View file @
00488a3e
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
#include <functional>
#include <functional>
#if defined(_MSC_VER)
#if defined(_MSC_VER)
#pragma warning(push)
#
pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#
pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
/* This will be true on all flat address space platforms and allows us to reduce the
...
@@ -156,8 +156,10 @@ NAMESPACE_END(detail)
...
@@ -156,8 +156,10 @@ NAMESPACE_END(detail)
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \
PyArray_GET_(ptr, flags)
#define PyArray_CHKFLAGS_(ptr, flag) \
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (
reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags
& flag))
(flag == (
PyArray_FLAGS_(ptr)
& flag))
class
dtype
:
public
object
{
class
dtype
:
public
object
{
public:
public:
...
@@ -258,38 +260,62 @@ public:
...
@@ -258,38 +260,62 @@ public:
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
forcecast
=
detail
::
npy_api
::
NPY_ARRAY_FORCECAST_
};
};
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
void
*
ptr
=
nullptr
)
{
const
std
::
vector
<
size_t
>
&
strides
,
const
void
*
ptr
=
nullptr
,
handle
base
=
handle
())
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
&
api
=
detail
::
npy_api
::
get
();
auto
ndim
=
shape
.
size
();
auto
ndim
=
shape
.
size
();
if
(
shape
.
size
()
!=
strides
.
size
())
if
(
shape
.
size
()
!=
strides
.
size
())
pybind11_fail
(
"NumPy: shape ndim doesn't match strides ndim"
);
pybind11_fail
(
"NumPy: shape ndim doesn't match strides ndim"
);
auto
descr
=
dt
;
auto
descr
=
dt
;
int
flags
=
0
;
if
(
base
&&
ptr
)
{
array
base_array
(
base
,
true
);
if
(
base_array
.
check
())
/* Copy flags from base (except baseship bit) */
flags
=
base_array
.
flags
()
&
~
detail
::
npy_api
::
NPY_ARRAY_OWNDATA_
;
else
/* Writable by default, easy to downgrade later on if needed */
flags
=
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
;
}
object
tmp
(
api
.
PyArray_NewFromDescr_
(
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
),
0
,
nullptr
),
false
);
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
),
flags
,
nullptr
),
false
);
if
(
!
tmp
)
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
if
(
ptr
)
{
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
if
(
base
)
{
PyArray_GET_
(
tmp
.
ptr
(),
base
)
=
base
.
inc_ref
().
ptr
();
}
else
{
tmp
=
object
(
api
.
PyArray_NewCopy_
(
tmp
.
ptr
(),
-
1
/* any order */
),
false
);
}
}
m_ptr
=
tmp
.
release
().
ptr
();
m_ptr
=
tmp
.
release
().
ptr
();
}
}
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>
&
shape
,
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
)
{
}
const
void
*
ptr
=
nullptr
,
handle
base
=
handle
())
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
,
base
)
{
}
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
,
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
handle
base
=
handle
())
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
)
const
std
::
vector
<
size_t
>&
strides
,
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
)
{
}
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
)
template
<
typename
T
>
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
)
{
}
array
(
const
std
::
vector
<
size_t
>
&
shape
,
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
,
base
)
{
}
template
<
typename
T
>
array
(
size_t
count
,
const
T
*
ptr
)
template
<
typename
T
>
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
array
(
size_t
count
,
const
T
*
ptr
,
handle
base
=
handle
())
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
,
base
)
{
}
array
(
const
buffer_info
&
info
)
array
(
const
buffer_info
&
info
)
:
array
(
pybind11
::
dtype
(
info
),
info
.
shape
,
info
.
strides
,
info
.
ptr
)
{
}
:
array
(
pybind11
::
dtype
(
info
),
info
.
shape
,
info
.
strides
,
info
.
ptr
)
{
}
...
@@ -319,6 +345,11 @@ public:
...
@@ -319,6 +345,11 @@ public:
return
(
size_t
)
PyArray_GET_
(
m_ptr
,
nd
);
return
(
size_t
)
PyArray_GET_
(
m_ptr
,
nd
);
}
}
/// Base object
object
base
()
const
{
return
object
(
PyArray_GET_
(
m_ptr
,
base
),
true
);
}
/// Dimensions of the array
/// Dimensions of the array
const
size_t
*
shape
()
const
{
const
size_t
*
shape
()
const
{
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
...
@@ -343,6 +374,11 @@ public:
...
@@ -343,6 +374,11 @@ public:
return
strides
()[
dim
];
return
strides
()[
dim
];
}
}
/// Return the NumPy array flags
int
flags
()
const
{
return
PyArray_FLAGS_
(
m_ptr
);
}
/// If set, the array is writeable (otherwise the buffer is read-only)
/// If set, the array is writeable (otherwise the buffer is read-only)
bool
writeable
()
const
{
bool
writeable
()
const
{
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
);
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
);
...
@@ -436,14 +472,17 @@ public:
...
@@ -436,14 +472,17 @@ public:
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>
&
shape
,
:
array
(
shape
,
strides
,
ptr
)
{
}
const
std
::
vector
<
size_t
>
&
strides
,
const
T
*
ptr
=
nullptr
,
handle
base
=
handle
())
:
array
(
shape
,
strides
,
ptr
,
base
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>
&
shape
,
const
T
*
ptr
=
nullptr
,
:
array
(
shape
,
ptr
)
{
}
handle
base
=
handle
())
:
array
(
shape
,
ptr
,
base
)
{
}
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
)
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
,
handle
base
=
handle
()
)
:
array
(
count
,
ptr
)
{
}
:
array
(
count
,
ptr
,
base
)
{
}
constexpr
size_t
itemsize
()
const
{
constexpr
size_t
itemsize
()
const
{
return
sizeof
(
T
);
return
sizeof
(
T
);
...
...
include/pybind11/pybind11.h
View file @
00488a3e
...
@@ -567,7 +567,7 @@ public:
...
@@ -567,7 +567,7 @@ public:
static
module
import
(
const
char
*
name
)
{
static
module
import
(
const
char
*
name
)
{
PyObject
*
obj
=
PyImport_ImportModule
(
name
);
PyObject
*
obj
=
PyImport_ImportModule
(
name
);
if
(
!
obj
)
if
(
!
obj
)
pybind11_fail
(
"Module
\"
"
+
std
::
string
(
name
)
+
"
\"
not found!"
);
throw
import_error
(
"Module
\"
"
+
std
::
string
(
name
)
+
"
\"
not found!"
);
return
module
(
obj
,
false
);
return
module
(
obj
,
false
);
}
}
};
};
...
@@ -1344,15 +1344,27 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
...
@@ -1344,15 +1344,27 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
auto
sep
=
kwargs
.
contains
(
"sep"
)
?
kwargs
[
"sep"
]
:
cast
(
" "
);
auto
sep
=
kwargs
.
contains
(
"sep"
)
?
kwargs
[
"sep"
]
:
cast
(
" "
);
auto
line
=
sep
.
attr
(
"join"
)(
strings
);
auto
line
=
sep
.
attr
(
"join"
)(
strings
);
auto
file
=
kwargs
.
contains
(
"file"
)
?
kwargs
[
"file"
].
cast
<
object
>
()
object
file
;
:
module
::
import
(
"sys"
).
attr
(
"stdout"
);
if
(
kwargs
.
contains
(
"file"
))
{
file
=
kwargs
[
"file"
].
cast
<
object
>
();
}
else
{
try
{
file
=
module
::
import
(
"sys"
).
attr
(
"stdout"
);
}
catch
(
const
import_error
&
)
{
/* If print() is called from code that is executed as
part of garbage collection during interpreter shutdown,
importing 'sys' can fail. Give up rather than crashing the
interpreter in this case. */
return
;
}
}
auto
write
=
file
.
attr
(
"write"
);
auto
write
=
file
.
attr
(
"write"
);
write
(
line
);
write
(
line
);
write
(
kwargs
.
contains
(
"end"
)
?
kwargs
[
"end"
]
:
cast
(
"
\n
"
));
write
(
kwargs
.
contains
(
"end"
)
?
kwargs
[
"end"
]
:
cast
(
"
\n
"
));
if
(
kwargs
.
contains
(
"flush"
)
&&
kwargs
[
"flush"
].
cast
<
bool
>
())
{
if
(
kwargs
.
contains
(
"flush"
)
&&
kwargs
[
"flush"
].
cast
<
bool
>
())
file
.
attr
(
"flush"
)();
file
.
attr
(
"flush"
)();
}
}
}
NAMESPACE_END
(
detail
)
NAMESPACE_END
(
detail
)
...
...
tests/test_numpy_array.cpp
View file @
00488a3e
...
@@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) {
...
@@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) {
sm
.
def
(
"make_c_array"
,
[]
{
sm
.
def
(
"make_c_array"
,
[]
{
return
py
::
array_t
<
float
>
({
2
,
2
},
{
8
,
4
});
return
py
::
array_t
<
float
>
({
2
,
2
},
{
8
,
4
});
});
});
sm
.
def
(
"wrap"
,
[](
py
::
array
a
)
{
return
py
::
array
(
a
.
dtype
(),
std
::
vector
<
size_t
>
(
a
.
shape
(),
a
.
shape
()
+
a
.
ndim
()),
std
::
vector
<
size_t
>
(
a
.
strides
(),
a
.
strides
()
+
a
.
ndim
()),
a
.
data
(),
a
);
});
struct
ArrayClass
{
int
data
[
2
]
=
{
1
,
2
};
ArrayClass
()
{
py
::
print
(
"ArrayClass()"
);
}
~
ArrayClass
()
{
py
::
print
(
"~ArrayClass()"
);
}
};
py
::
class_
<
ArrayClass
>
(
sm
,
"ArrayClass"
)
.
def
(
py
::
init
<>
())
.
def
(
"numpy_view"
,
[](
py
::
object
&
obj
)
{
py
::
print
(
"ArrayClass::numpy_view()"
);
ArrayClass
&
a
=
obj
.
cast
<
ArrayClass
&>
();
return
py
::
array_t
<
int
>
({
2
},
{
4
},
a
.
data
,
obj
);
}
);
});
});
tests/test_numpy_array.py
View file @
00488a3e
import
pytest
import
pytest
import
gc
with
pytest
.
suppress
(
ImportError
):
with
pytest
.
suppress
(
ImportError
):
import
numpy
as
np
import
numpy
as
np
...
@@ -149,6 +150,7 @@ def test_bounds_check(arr):
...
@@ -149,6 +150,7 @@ def test_bounds_check(arr):
index_at
(
arr
,
0
,
4
)
index_at
(
arr
,
0
,
4
)
assert
str
(
excinfo
.
value
)
==
'index 4 is out of bounds for axis 1 with size 3'
assert
str
(
excinfo
.
value
)
==
'index 4 is out of bounds for axis 1 with size 3'
@
pytest
.
requires_numpy
@
pytest
.
requires_numpy
def
test_make_c_f_array
():
def
test_make_c_f_array
():
from
pybind11_tests.array
import
(
from
pybind11_tests.array
import
(
...
@@ -158,3 +160,81 @@ def test_make_c_f_array():
...
@@ -158,3 +160,81 @@ def test_make_c_f_array():
assert
not
make_c_array
().
flags
.
f_contiguous
assert
not
make_c_array
().
flags
.
f_contiguous
assert
make_f_array
().
flags
.
f_contiguous
assert
make_f_array
().
flags
.
f_contiguous
assert
not
make_f_array
().
flags
.
c_contiguous
assert
not
make_f_array
().
flags
.
c_contiguous
@
pytest
.
requires_numpy
def
test_wrap
():
from
pybind11_tests.array
import
wrap
def
assert_references
(
A
,
B
):
assert
A
is
not
B
assert
A
.
__array_interface__
[
'data'
][
0
]
==
\
B
.
__array_interface__
[
'data'
][
0
]
assert
A
.
shape
==
B
.
shape
assert
A
.
strides
==
B
.
strides
assert
A
.
flags
.
c_contiguous
==
B
.
flags
.
c_contiguous
assert
A
.
flags
.
f_contiguous
==
B
.
flags
.
f_contiguous
assert
A
.
flags
.
writeable
==
B
.
flags
.
writeable
assert
A
.
flags
.
aligned
==
B
.
flags
.
aligned
assert
A
.
flags
.
updateifcopy
==
B
.
flags
.
updateifcopy
assert
np
.
all
(
A
==
B
)
assert
not
B
.
flags
.
owndata
assert
B
.
base
is
A
if
A
.
flags
.
writeable
and
A
.
ndim
==
2
:
A
[
0
,
0
]
=
1234
assert
B
[
0
,
0
]
==
1234
A1
=
np
.
array
([
1
,
2
],
dtype
=
np
.
int16
)
assert
A1
.
flags
.
owndata
and
A1
.
base
is
None
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
np
.
float32
,
order
=
'F'
)
assert
A1
.
flags
.
owndata
and
A1
.
base
is
None
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
np
.
float32
,
order
=
'C'
)
A1
.
flags
.
writeable
=
False
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
np
.
random
.
random
((
4
,
4
,
4
))
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
A1
.
transpose
()
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
A1
=
A1
.
diagonal
()
A2
=
wrap
(
A1
)
assert_references
(
A1
,
A2
)
@
pytest
.
requires_numpy
def
test_numpy_view
(
capture
):
from
pybind11_tests.array
import
ArrayClass
with
capture
:
ac
=
ArrayClass
()
ac_view_1
=
ac
.
numpy_view
()
ac_view_2
=
ac
.
numpy_view
()
assert
np
.
all
(
ac_view_1
==
np
.
array
([
1
,
2
],
dtype
=
np
.
int32
))
del
ac
gc
.
collect
()
assert
capture
==
"""
ArrayClass()
ArrayClass::numpy_view()
ArrayClass::numpy_view()
"""
ac_view_1
[
0
]
=
4
ac_view_1
[
1
]
=
3
assert
ac_view_2
[
0
]
==
4
assert
ac_view_2
[
1
]
==
3
with
capture
:
del
ac_view_1
del
ac_view_2
gc
.
collect
()
assert
capture
==
"""
~ArrayClass()
"""
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment