Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
f2a0ad58
Commit
f2a0ad58
authored
Sep 08, 2016
by
Ivan Smirnov
Browse files
array: add direct data access and indexing methods
parent
91b3d681
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
102 additions
and
33 deletions
+102
-33
include/pybind11/numpy.h
include/pybind11/numpy.h
+102
-33
No files found.
include/pybind11/numpy.h
View file @
f2a0ad58
...
...
@@ -26,8 +26,14 @@
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
upon the library user. */
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
NAMESPACE_BEGIN
(
pybind11
)
namespace
detail
{
NAMESPACE_BEGIN
(
detail
)
template
<
typename
type
,
typename
SFINAE
=
void
>
struct
npy_format_descriptor
{
};
template
<
typename
type
>
struct
is_pod_struct
;
...
...
@@ -141,10 +147,12 @@ private:
return
api
;
}
};
}
NAMESPACE_END
(
detail
)
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
...
...
@@ -250,7 +258,7 @@ public:
};
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
void
*
ptr
=
nullptr
)
{
const
std
::
vector
<
size_t
>&
strides
,
const
void
*
ptr
=
nullptr
)
{
auto
&
api
=
detail
::
npy_api
::
get
();
auto
ndim
=
shape
.
size
();
if
(
shape
.
size
()
!=
strides
.
size
())
...
...
@@ -258,7 +266,7 @@ public:
auto
descr
=
dt
;
object
tmp
(
api
.
PyArray_NewFromDescr_
(
api
.
PyArray_Type_
,
descr
.
release
().
ptr
(),
(
int
)
ndim
,
(
Py_intptr_t
*
)
shape
.
data
(),
(
Py_intptr_t
*
)
strides
.
data
(),
ptr
,
0
,
nullptr
),
false
);
(
Py_intptr_t
*
)
strides
.
data
(),
const_cast
<
void
*>
(
ptr
)
,
0
,
nullptr
),
false
);
if
(
!
tmp
)
pybind11_fail
(
"NumPy: unable to create array!"
);
if
(
ptr
)
...
...
@@ -266,20 +274,20 @@ public:
m_ptr
=
tmp
.
release
().
ptr
();
}
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
const
std
::
vector
<
size_t
>&
shape
,
const
void
*
ptr
=
nullptr
)
:
array
(
dt
,
shape
,
default_strides
(
shape
,
dt
.
itemsize
()),
ptr
)
{
}
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
void
*
ptr
=
nullptr
)
array
(
const
pybind11
::
dtype
&
dt
,
size_t
count
,
const
void
*
ptr
=
nullptr
)
:
array
(
dt
,
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
T
*
ptr
)
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
)
:
array
(
pybind11
::
dtype
::
of
<
T
>
(),
shape
,
strides
,
(
void
*
)
ptr
)
{
}
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
)
template
<
typename
T
>
array
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
)
:
array
(
shape
,
default_strides
(
shape
,
sizeof
(
T
)),
ptr
)
{
}
template
<
typename
T
>
array
(
size_t
count
,
T
*
ptr
)
template
<
typename
T
>
array
(
size_t
count
,
const
T
*
ptr
)
:
array
(
std
::
vector
<
size_t
>
{
count
},
ptr
)
{
}
array
(
const
buffer_info
&
info
)
...
...
@@ -312,27 +320,25 @@ public:
/// Dimensions of the array
const
size_t
*
shape
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
dimensions
));
}
/// Dimension along a given axis
size_t
shape
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index shape beyond ndim
"
);
fail_dim_check
(
dim
,
"invalid axis
"
);
return
shape
()[
dim
];
}
/// Strides of the array
const
size_t
*
strides
()
const
{
static_assert
(
sizeof
(
size_t
)
==
sizeof
(
Py_intptr_t
),
"size_t != Py_intptr_t"
);
return
reinterpret_cast
<
const
size_t
*>
(
PyArray_GET_
(
m_ptr
,
strides
));
}
/// Stride along a given axis
size_t
strides
(
size_t
dim
)
const
{
if
(
dim
>=
ndim
())
pybind11_fail
(
"NumPy: attempted to index strides beyond ndim
"
);
fail_dim_check
(
dim
,
"invalid axis
"
);
return
strides
()[
dim
];
}
...
...
@@ -346,20 +352,61 @@ public:
return
PyArray_CHKFLAGS_
(
m_ptr
,
detail
::
npy_api
::
NPY_ARRAY_OWNDATA_
);
}
/// Direct pointer to contained buffer
const
void
*
data
()
const
{
return
reinterpret_cast
<
const
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
/// Pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
const
void
*
data
(
Ix
&&
...
index
)
const
{
return
static_cast
<
const
void
*>
(
PyArray_GET_
(
m_ptr
,
data
)
+
offset_at
(
index
...));
}
/// Direct mutable pointer to contained buffer (checks writeable flag)
void
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
void
*>
(
PyArray_GET_
(
m_ptr
,
data
));
/// Mutable pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
/// May throw if the array is not writeable.
template
<
typename
...
Ix
>
void
*
mutable_data
(
Ix
&&
...
index
)
{
check_writeable
();
return
static_cast
<
void
*>
(
PyArray_GET_
(
m_ptr
,
data
)
+
offset_at
(
index
...));
}
/// Byte offset from beginning of the array to a given index (full or partial).
/// May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
size_t
offset_at
(
Ix
&&
...
index
)
const
{
if
(
sizeof
...(
index
)
>
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"too many indices for an array"
);
return
get_byte_offset
(
index
...);
}
size_t
offset_at
()
const
{
return
0
;
}
/// Item count from beginning of the array to a given index (full or partial).
/// May throw if the index would lead to out of bounds access.
template
<
typename
...
Ix
>
size_t
index_at
(
Ix
&&
...
index
)
const
{
return
offset_at
(
index
...)
/
itemsize
();
}
protected:
template
<
typename
T
,
typename
SFINAE
>
friend
struct
detail
::
npy_format_descriptor
;
template
<
typename
,
typename
>
friend
struct
detail
::
npy_format_descriptor
;
void
fail_dim_check
(
size_t
dim
,
const
std
::
string
&
msg
)
const
{
throw
index_error
(
msg
+
": "
+
std
::
to_string
(
dim
)
+
" (ndim = "
+
std
::
to_string
(
ndim
())
+
")"
);
}
template
<
typename
...
Ix
>
size_t
get_byte_offset
(
Ix
&&
...
index
)
const
{
const
size_t
idx
[]
=
{
(
size_t
)
index
...
};
if
(
!
std
::
equal
(
idx
+
0
,
idx
+
sizeof
...(
index
),
shape
(),
std
::
less
<
size_t
>
{}))
{
auto
mismatch
=
std
::
mismatch
(
idx
+
0
,
idx
+
sizeof
...(
index
),
shape
(),
std
::
less
<
size_t
>
{});
throw
index_error
(
std
::
string
(
"index "
)
+
std
::
to_string
(
*
mismatch
.
first
)
+
" is out of bounds for axis "
+
std
::
to_string
(
mismatch
.
first
-
idx
)
+
" with size "
+
std
::
to_string
(
*
mismatch
.
second
));
}
return
std
::
inner_product
(
idx
+
0
,
idx
+
sizeof
...(
index
),
strides
(),
(
size_t
)
0
);
}
size_t
get_byte_offset
()
const
{
return
0
;
}
void
check_writeable
()
const
{
if
(
!
writeable
())
throw
std
::
runtime_error
(
"array is not writeable"
);
}
static
std
::
vector
<
size_t
>
default_strides
(
const
std
::
vector
<
size_t
>&
shape
,
size_t
itemsize
)
{
auto
ndim
=
shape
.
size
();
...
...
@@ -382,23 +429,45 @@ public:
array_t
(
const
buffer_info
&
info
)
:
array
(
info
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
std
::
vector
<
size_t
>&
strides
,
const
T
*
ptr
=
nullptr
)
:
array
(
shape
,
strides
,
ptr
)
{
}
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
T
*
ptr
=
nullptr
)
array_t
(
const
std
::
vector
<
size_t
>&
shape
,
const
T
*
ptr
=
nullptr
)
:
array
(
shape
,
ptr
)
{
}
array_t
(
size_t
count
,
T
*
ptr
=
nullptr
)
array_t
(
size_t
count
,
const
T
*
ptr
=
nullptr
)
:
array
(
count
,
ptr
)
{
}
const
T
*
data
()
const
{
return
reinterpret_cast
<
const
T
*>
(
PyArray_GET_
(
m_ptr
,
data
)
);
const
expr
size_t
itemsize
()
const
{
return
sizeof
(
T
);
}
T
*
mutable_data
()
{
if
(
!
writeable
())
pybind11_fail
(
"NumPy: cannot get mutable data of a read-only array"
);
return
reinterpret_cast
<
T
*>
(
PyArray_GET_
(
m_ptr
,
data
));
template
<
typename
...
Ix
>
size_t
index_at
(
Ix
&
...
index
)
const
{
return
offset_at
(
index
...)
/
itemsize
();
}
template
<
typename
...
Ix
>
const
T
*
data
(
Ix
&&
...
index
)
const
{
return
static_cast
<
const
T
*>
(
array
::
data
(
index
...));
}
template
<
typename
...
Ix
>
T
*
mutable_data
(
Ix
&&
...
index
)
{
return
static_cast
<
T
*>
(
array
::
mutable_data
(
index
...));
}
// Reference to element at a given index
template
<
typename
...
Ix
>
const
T
&
at
(
Ix
&&
...
index
)
const
{
if
(
sizeof
...(
index
)
!=
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"index dimension mismatch"
);
// not using offset_at() / index_at() here so as to avoid another dimension check
return
*
(
static_cast
<
const
T
*>
(
array
::
data
())
+
get_byte_offset
(
index
...)
/
itemsize
());
}
// Mutable reference to element at a given index
template
<
typename
...
Ix
>
T
&
mutable_at
(
Ix
&&
...
index
)
{
if
(
sizeof
...(
index
)
!=
ndim
())
fail_dim_check
(
sizeof
...(
index
),
"index dimension mismatch"
);
// not using offset_at() / index_at() here so as to avoid another dimension check
return
*
(
static_cast
<
T
*>
(
array
::
mutable_data
())
+
get_byte_offset
(
index
...)
/
itemsize
());
}
static
bool
is_non_null
(
PyObject
*
ptr
)
{
return
ptr
!=
nullptr
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment