Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
91b3d681
Commit
91b3d681
authored
Aug 29, 2016
by
Ivan Smirnov
Browse files
Expose some dtype/array attributes via NumPy C API
parent
720136bf
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
246 additions
and
56 deletions
+246
-56
include/pybind11/eigen.h
include/pybind11/eigen.h
+26
-36
include/pybind11/numpy.h
include/pybind11/numpy.h
+131
-20
tests/CMakeLists.txt
tests/CMakeLists.txt
+1
-0
tests/test_numpy_array.cpp
tests/test_numpy_array.cpp
+45
-0
tests/test_numpy_array.py
tests/test_numpy_array.py
+43
-0
No files found.
include/pybind11/eigen.h
View file @
91b3d681
...
...
@@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
static
constexpr
bool
isVector
=
Type
::
IsVectorAtCompileTime
;
bool
load
(
handle
src
,
bool
)
{
array_t
<
Scalar
>
buf
fer
(
src
,
true
);
if
(
!
buf
fer
.
check
())
return
false
;
array_t
<
Scalar
>
buf
(
src
,
true
);
if
(
!
buf
.
check
())
return
false
;
auto
info
=
buffer
.
request
();
if
(
info
.
ndim
==
1
)
{
if
(
buf
.
ndim
()
==
1
)
{
typedef
Eigen
::
InnerStride
<>
Strides
;
if
(
!
isVector
&&
!
(
Type
::
RowsAtCompileTime
==
Eigen
::
Dynamic
&&
...
...
@@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
return
false
;
if
(
Type
::
SizeAtCompileTime
!=
Eigen
::
Dynamic
&&
info
.
shape
[
0
]
!=
(
size_t
)
Type
::
SizeAtCompileTime
)
buf
.
shape
(
0
)
!=
(
size_t
)
Type
::
SizeAtCompileTime
)
return
false
;
auto
strides
=
Strides
(
info
.
strides
[
0
]
/
sizeof
(
Scalar
));
Strides
::
Index
n_elts
=
(
Strides
::
Index
)
info
.
shape
[
0
];
Strides
::
Index
n_elts
=
(
Strides
::
Index
)
buf
.
shape
(
0
);
Strides
::
Index
unity
=
1
;
value
=
Eigen
::
Map
<
Type
,
0
,
Strides
>
(
(
Scalar
*
)
info
.
ptr
,
rowMajor
?
unity
:
n_elts
,
rowMajor
?
n_elts
:
unity
,
strides
);
}
else
if
(
info
.
ndim
==
2
)
{
buf
.
mutable_data
(),
rowMajor
?
unity
:
n_elts
,
rowMajor
?
n_elts
:
unity
,
Strides
(
buf
.
strides
(
0
)
/
sizeof
(
Scalar
))
);
}
else
if
(
buf
.
ndim
()
==
2
)
{
typedef
Eigen
::
Stride
<
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>
Strides
;
if
((
Type
::
RowsAtCompileTime
!=
Eigen
::
Dynamic
&&
info
.
shape
[
0
]
!=
(
size_t
)
Type
::
RowsAtCompileTime
)
||
(
Type
::
ColsAtCompileTime
!=
Eigen
::
Dynamic
&&
info
.
shape
[
1
]
!=
(
size_t
)
Type
::
ColsAtCompileTime
))
if
((
Type
::
RowsAtCompileTime
!=
Eigen
::
Dynamic
&&
buf
.
shape
(
0
)
!=
(
size_t
)
Type
::
RowsAtCompileTime
)
||
(
Type
::
ColsAtCompileTime
!=
Eigen
::
Dynamic
&&
buf
.
shape
(
1
)
!=
(
size_t
)
Type
::
ColsAtCompileTime
))
return
false
;
auto
strides
=
Strides
(
info
.
strides
[
rowMajor
?
0
:
1
]
/
sizeof
(
Scalar
),
info
.
strides
[
rowMajor
?
1
:
0
]
/
sizeof
(
Scalar
));
value
=
Eigen
::
Map
<
Type
,
0
,
Strides
>
(
(
Scalar
*
)
info
.
ptr
,
typename
Strides
::
Index
(
info
.
shape
[
0
]),
typename
Strides
::
Index
(
info
.
shape
[
1
]),
strides
);
buf
.
mutable_data
(),
typename
Strides
::
Index
(
buf
.
shape
(
0
)),
typename
Strides
::
Index
(
buf
.
shape
(
1
)),
Strides
(
buf
.
strides
(
rowMajor
?
0
:
1
)
/
sizeof
(
Scalar
),
buf
.
strides
(
rowMajor
?
1
:
0
)
/
sizeof
(
Scalar
))
);
}
else
{
return
false
;
}
...
...
@@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
}
}
auto
values
Array
=
array_t
<
Scalar
>
((
object
)
obj
.
attr
(
"data"
));
auto
innerIndices
Array
=
array_t
<
StorageIndex
>
((
object
)
obj
.
attr
(
"indices"
));
auto
outerIndices
Array
=
array_t
<
StorageIndex
>
((
object
)
obj
.
attr
(
"indptr"
));
auto
values
=
array_t
<
Scalar
>
((
object
)
obj
.
attr
(
"data"
));
auto
innerIndices
=
array_t
<
StorageIndex
>
((
object
)
obj
.
attr
(
"indices"
));
auto
outerIndices
=
array_t
<
StorageIndex
>
((
object
)
obj
.
attr
(
"indptr"
));
auto
shape
=
pybind11
::
tuple
((
pybind11
::
object
)
obj
.
attr
(
"shape"
));
auto
nnz
=
obj
.
attr
(
"nnz"
).
cast
<
Index
>
();
if
(
!
valuesArray
.
check
()
||
!
innerIndicesArray
.
check
()
||
!
outerIndicesArray
.
check
())
if
(
!
values
.
check
()
||
!
innerIndices
.
check
()
||
!
outerIndices
.
check
())
return
false
;
auto
outerIndices
=
outerIndicesArray
.
request
();
auto
innerIndices
=
innerIndicesArray
.
request
();
auto
values
=
valuesArray
.
request
();
value
=
Eigen
::
MappedSparseMatrix
<
Scalar
,
Type
::
Flags
,
StorageIndex
>
(
shape
[
0
].
cast
<
Index
>
(),
shape
[
1
].
cast
<
Index
>
(),
nnz
,
static_cast
<
StorageIndex
*>
(
outerIndices
.
ptr
),
static_cast
<
StorageIndex
*>
(
innerIndices
.
ptr
),
static_cast
<
Scalar
*>
(
values
.
ptr
)
);
shape
[
0
].
cast
<
Index
>
(),
shape
[
1
].
cast
<
Index
>
(),
nnz
,
outerIndices
.
mutable_data
(),
innerIndices
.
mutable_data
(),
values
.
mutable_data
());
return
true
;
}
...
...
include/pybind11/numpy.h
View file @
91b3d681
...
...
@@ -19,6 +19,7 @@
#include <sstream>
#include <string>
#include <initializer_list>
#include <functional>
#if defined(_MSC_VER)
#pragma warning(push)
...
...
@@ -30,12 +31,41 @@ namespace detail {
template
<
typename
type
,
typename
SFINAE
=
void
>
struct
npy_format_descriptor
{
};
template
<
typename
type
>
struct
is_pod_struct
;
struct
PyArrayDescr_Proxy
{
PyObject_HEAD
PyObject
*
typeobj
;
char
kind
;
char
type
;
char
byteorder
;
char
flags
;
int
type_num
;
int
elsize
;
int
alignment
;
char
*
subarray
;
PyObject
*
fields
;
PyObject
*
names
;
};
struct
PyArray_Proxy
{
PyObject_HEAD
char
*
data
;
int
nd
;
ssize_t
*
dimensions
;
ssize_t
*
strides
;
PyObject
*
base
;
PyObject
*
descr
;
int
flags
;
};
struct
npy_api
{
enum
constants
{
NPY_C_CONTIGUOUS_
=
0x0001
,
NPY_F_CONTIGUOUS_
=
0x0002
,
NPY_ARRAY_OWNDATA_
=
0x0004
,
NPY_ARRAY_FORCECAST_
=
0x0010
,
NPY_ENSURE_ARRAY_
=
0x0040
,
NPY_ARRAY_ALIGNED_
=
0x0100
,
NPY_ARRAY_WRITEABLE_
=
0x0400
,
NPY_BOOL_
=
0
,
NPY_BYTE_
,
NPY_UBYTE_
,
NPY_SHORT_
,
NPY_USHORT_
,
...
...
@@ -113,6 +143,11 @@ private:
};
}
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
class
dtype
:
public
object
{
public:
PYBIND11_OBJECT_DEFAULT
(
dtype
,
object
,
detail
::
npy_api
::
get
().
PyArrayDescr_Check_
);
...
...
@@ -150,15 +185,15 @@ public:
}
size_t
itemsize
()
const
{
return
attr
(
"itemsize"
).
cast
<
size_t
>
(
);
return
(
size_t
)
PyArrayDescr_GET_
(
m_ptr
,
elsize
);
}
bool
has_fields
()
const
{
return
attr
(
"fields"
).
cast
<
object
>
().
ptr
()
!=
Py_None
;
return
PyArrayDescr_GET_
(
m_ptr
,
names
)
!=
nullptr
;
}
std
::
string
kind
()
const
{
return
(
std
::
string
)
attr
(
"kind"
).
cast
<
pybind11
::
str
>
(
);
char
kind
()
const
{
return
PyArrayDescr_GET_
(
m_ptr
,
kind
);
}
private:
...
...
@@ -171,20 +206,20 @@ private:
dtype
strip_padding
()
{
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
auto
fields
=
attr
(
"fields"
).
cast
<
object
>
();
if
(
fields
.
ptr
()
==
Py_None
)
if
(
!
has_fields
())
return
*
this
;
struct
field_descr
{
PYBIND11_STR_TYPE
name
;
object
format
;
pybind11
::
int_
offset
;
};
std
::
vector
<
field_descr
>
field_descriptors
;
auto
fields
=
attr
(
"fields"
).
cast
<
object
>
();
auto
items
=
fields
.
attr
(
"items"
).
cast
<
object
>
();
for
(
auto
field
:
items
())
{
auto
spec
=
object
(
field
,
true
).
cast
<
tuple
>
();
auto
name
=
spec
[
0
].
cast
<
pybind11
::
str
>
();
auto
format
=
spec
[
1
].
cast
<
tuple
>
()[
0
].
cast
<
dtype
>
();
auto
offset
=
spec
[
1
].
cast
<
tuple
>
()[
1
].
cast
<
pybind11
::
int_
>
();
if
(
!
len
(
name
)
&&
format
.
kind
()
==
"V"
)
if
(
!
len
(
name
)
&&
format
.
kind
()
==
'V'
)
continue
;
field_descriptors
.
push_back
({(
PYBIND11_STR_TYPE
)
name
,
format
.
strip_padding
(),
offset
});
}
...
...
@@ -244,14 +279,83 @@ public:
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
)
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
)
{
}
template
<
typename
T
>
array
(
size_t
size
,
T
*
ptr
)
:
array
(
std
::
vector
<
size_t
>
{
size
},
ptr
)
{
}
template
<
typename
T
>
array
(
size_t
count
,
T
*
ptr
)
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
array
(
const
buffer_info
&
info
)
:
array
(
pybind11
::
dtype
(
info
),
info
.
shape
,
info
.
strides
,
info
.
ptr
)
{
}
pybind11
::
dtype
dtype
()
{
return
attr
(
"dtype"
).
cast
<
pybind11
::
dtype
>
();
/// Array descriptor (dtype)
pybind11
::
dtype
dtype
()
const
{
return
object
(
PyArray_GET_
(
m_ptr
,
descr
),
true
);
}
/// Total number of elements
size_t
size
()
const
{
return
std
::
accumulate
(
shape
(),
shape
()
+
ndim
(),
(
size_t
)
1
,
std
::
multiplies
<
size_t
>
());
}
/// Byte size of a single element
size_t
itemsize
()
const
{
return
(
size_t
)
PyArrayDescr_GET_
(
PyArray_GET_
(
m_ptr
,
descr
),
elsize
);
}
/// Total number of bytes
size_t
nbytes
()
const
{
return
size
()
*
itemsize
();
}
/// Number of dimensions
size_t
ndim
()
const
{
return
(
size_t
)
PyArray_GET_
(
m_ptr
,
nd
);
}
/// Dimensions of the array
const
size_t
*
shape
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
}
/// Dimension along a given axis
size_t
shape
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index shape beyond ndim"
);
return
shape
()[
dim
];
}
/// Strides of the array
const
size_t
*
strides
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
strides
));
}
/// Stride along a given axis
size_t
strides
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index strides beyond ndim"
);
return
strides
()[
dim
];
}
/// If set, the array is writeable (otherwise the buffer is read-only)
bool
writeable
()
const
{
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_WRITEABLE_
);
}
/// If set, the array owns the data (will be freed when the array is deleted)
bool
owndata
()
const
{
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_OWNDATA_
);
}
/// Direct pointer to contained buffer
const
void
*
data
()
const
{
return
reinterpret_cast
<
const
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
}
/// Direct mutable pointer to contained buffer (checks writeable flag)
void
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
}
protected:
...
...
@@ -284,8 +388,18 @@ public:
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
=
nullptr
)
:
array
(
shape
,
ptr
)
{
}
array_t
(
size_t
size
,
T
*
ptr
=
nullptr
)
:
array
(
size
,
ptr
)
{
}
array_t
(
size_t
count
,
T
*
ptr
=
nullptr
)
:
array
(
count
,
ptr
)
{
}
const
T
*
data
()
const
{
return
reinterpret_cast
<
const
T
*>
(
PyArray_GET_
(
m_ptr
,
data
));
}
T
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
T
*>
(
PyArray_GET_
(
m_ptr
,
data
));
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
...
...
@@ -678,16 +792,13 @@ struct vectorize_helper {
if
(
size
==
1
)
return
cast
(
f
(
*
((
Args
*
)
buffers
[
Index
].
ptr
)...));
array
result
(
buffer_info
(
nullptr
,
sizeof
(
Return
),
format_descriptor
<
Return
>::
format
(),
ndim
,
shape
,
strides
));
buffer_info
buf
=
result
.
request
();
Return
*
output
=
(
Return
*
)
buf
.
ptr
;
array_t
<
Return
>
result
(
shape
,
strides
);
auto
buf
=
result
.
request
();
auto
output
=
(
Return
*
)
buf
.
ptr
;
if
(
trivial_broadcast
)
{
/* Call the function */
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
output
[
i
]
=
f
((
buffers
[
Index
].
size
==
1
?
*
((
Args
*
)
buffers
[
Index
].
ptr
)
:
((
Args
*
)
buffers
[
Index
].
ptr
)[
i
])...);
...
...
tests/CMakeLists.txt
View file @
91b3d681
...
...
@@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
test_kwargs_and_defaults.cpp
test_methods_and_attributes.cpp
test_modules.cpp
test_numpy_array.cpp
test_numpy_dtypes.cpp
test_numpy_vectorize.cpp
test_opaque_types.cpp
...
...
tests/test_numpy_array.cpp
0 → 100644
View file @
91b3d681
/*
tests/test_numpy_array.cpp -- test core array functionality
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include "pybind11_tests.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
test_initializer
numpy_array
([](
py
::
module
&
m
)
{
m
.
def
(
"get_arr_ndim"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
ndim
();
});
m
.
def
(
"get_arr_shape"
,
[](
const
py
::
array
&
arr
)
{
return
std
::
vector
<
size_t
>
(
arr
.
shape
(),
arr
.
shape
()
+
arr
.
ndim
());
});
m
.
def
(
"get_arr_shape"
,
[](
const
py
::
array
&
arr
,
size_t
dim
)
{
return
arr
.
shape
(
dim
);
});
m
.
def
(
"get_arr_strides"
,
[](
const
py
::
array
&
arr
)
{
return
std
::
vector
<
size_t
>
(
arr
.
strides
(),
arr
.
strides
()
+
arr
.
ndim
());
});
m
.
def
(
"get_arr_strides"
,
[](
const
py
::
array
&
arr
,
size_t
dim
)
{
return
arr
.
strides
(
dim
);
});
m
.
def
(
"get_arr_writeable"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
writeable
();
});
m
.
def
(
"get_arr_size"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
size
();
});
m
.
def
(
"get_arr_itemsize"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
itemsize
();
});
m
.
def
(
"get_arr_nbytes"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
nbytes
();
});
m
.
def
(
"get_arr_owndata"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
owndata
();
});
});
tests/test_numpy_array.py
0 → 100644
View file @
91b3d681
import
pytest
with
pytest
.
suppress
(
ImportError
):
import
numpy
as
np
@
pytest
.
requires_numpy
def
test_array_attributes
():
from
pybind11_tests
import
(
get_arr_ndim
,
get_arr_shape
,
get_arr_strides
,
get_arr_writeable
,
get_arr_size
,
get_arr_itemsize
,
get_arr_nbytes
,
get_arr_owndata
)
a
=
np
.
array
(
0
,
'f8'
)
assert
get_arr_ndim
(
a
)
==
0
assert
get_arr_shape
(
a
)
==
[]
assert
get_arr_strides
(
a
)
==
[]
with
pytest
.
raises
(
RuntimeError
):
get_arr_shape
(
a
,
1
)
with
pytest
.
raises
(
RuntimeError
):
get_arr_strides
(
a
,
0
)
assert
get_arr_writeable
(
a
)
assert
get_arr_size
(
a
)
==
1
assert
get_arr_itemsize
(
a
)
==
8
assert
get_arr_nbytes
(
a
)
==
8
assert
get_arr_owndata
(
a
)
a
=
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
'u2'
).
view
()
a
.
flags
.
writeable
=
False
assert
get_arr_ndim
(
a
)
==
2
assert
get_arr_shape
(
a
)
==
[
2
,
3
]
assert
get_arr_shape
(
a
,
0
)
==
2
assert
get_arr_shape
(
a
,
1
)
==
3
assert
get_arr_strides
(
a
)
==
[
6
,
2
]
assert
get_arr_strides
(
a
,
0
)
==
6
assert
get_arr_strides
(
a
,
1
)
==
2
with
pytest
.
raises
(
RuntimeError
):
get_arr_shape
(
a
,
2
)
with
pytest
.
raises
(
RuntimeError
):
get_arr_strides
(
a
,
2
)
assert
not
get_arr_writeable
(
a
)
assert
get_arr_size
(
a
)
==
6
assert
get_arr_itemsize
(
a
)
==
2
assert
get_arr_nbytes
(
a
)
==
12
assert
not
get_arr_owndata
(
a
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment