Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
aca6bcae
Commit
aca6bcae
authored
Sep 08, 2016
by
Ivan Smirnov
Browse files
Add tests for array data access /index methods
parent
f2a0ad58
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
216 additions
and
60 deletions
+216
-60
tests/test_numpy_array.cpp
tests/test_numpy_array.cpp
+79
-30
tests/test_numpy_array.py
tests/test_numpy_array.py
+137
-30
No files found.
tests/test_numpy_array.cpp
View file @
aca6bcae
...
@@ -8,38 +8,87 @@
...
@@ -8,38 +8,87 @@
*/
*/
#include "pybind11_tests.h"
#include "pybind11_tests.h"
#include <pybind11/numpy.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <vector>
using
arr
=
py
::
array
;
using
arr_t
=
py
::
array_t
<
uint16_t
,
0
>
;
template
<
typename
...
Ix
>
arr
data
(
const
arr
&
a
,
Ix
&&
...
index
)
{
return
arr
(
a
.
nbytes
()
-
a
.
offset_at
(
index
...),
(
const
uint8_t
*
)
a
.
data
(
index
...));
}
template
<
typename
...
Ix
>
arr
data_t
(
const
arr_t
&
a
,
Ix
&&
...
index
)
{
return
arr
(
a
.
size
()
-
a
.
index_at
(
index
...),
a
.
data
(
index
...));
}
arr
&
mutate_data
(
arr
&
a
)
{
auto
ptr
=
(
uint8_t
*
)
a
.
mutable_data
();
for
(
size_t
i
=
0
;
i
<
a
.
nbytes
();
i
++
)
ptr
[
i
]
=
(
uint8_t
)
(
ptr
[
i
]
*
2
);
return
a
;
}
arr_t
&
mutate_data_t
(
arr_t
&
a
)
{
auto
ptr
=
a
.
mutable_data
();
for
(
size_t
i
=
0
;
i
<
a
.
size
();
i
++
)
ptr
[
i
]
++
;
return
a
;
}
template
<
typename
...
Ix
>
arr
&
mutate_data
(
arr
&
a
,
Ix
&&
...
index
)
{
auto
ptr
=
(
uint8_t
*
)
a
.
mutable_data
(
index
...);
for
(
size_t
i
=
0
;
i
<
a
.
nbytes
()
-
a
.
offset_at
(
index
...);
i
++
)
ptr
[
i
]
=
(
uint8_t
)
(
ptr
[
i
]
*
2
);
return
a
;
}
template
<
typename
...
Ix
>
arr_t
&
mutate_data_t
(
arr_t
&
a
,
Ix
&&
...
index
)
{
auto
ptr
=
a
.
mutable_data
(
index
...);
for
(
size_t
i
=
0
;
i
<
a
.
size
()
-
a
.
index_at
(
index
...);
i
++
)
ptr
[
i
]
++
;
return
a
;
}
template
<
typename
...
Ix
>
size_t
index_at
(
const
arr
&
a
,
Ix
&&
...
idx
)
{
return
a
.
index_at
(
idx
...);
}
template
<
typename
...
Ix
>
size_t
index_at_t
(
const
arr_t
&
a
,
Ix
&&
...
idx
)
{
return
a
.
index_at
(
idx
...);
}
template
<
typename
...
Ix
>
size_t
offset_at
(
const
arr
&
a
,
Ix
&&
...
idx
)
{
return
a
.
offset_at
(
idx
...);
}
template
<
typename
...
Ix
>
size_t
offset_at_t
(
const
arr_t
&
a
,
Ix
&&
...
idx
)
{
return
a
.
offset_at
(
idx
...);
}
template
<
typename
...
Ix
>
size_t
at_t
(
const
arr_t
&
a
,
Ix
&&
...
idx
)
{
return
a
.
at
(
idx
...);
}
template
<
typename
...
Ix
>
arr_t
&
mutate_at_t
(
arr_t
&
a
,
Ix
&&
...
idx
)
{
a
.
mutable_at
(
idx
...)
++
;
return
a
;
}
#define def_index_fn(name, type) \
sm.def(#name, [](type a) { return name(a); }); \
sm.def(#name, [](type a, int i) { return name(a, i); }); \
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
test_initializer
numpy_array
([](
py
::
module
&
m
)
{
test_initializer
numpy_array
([](
py
::
module
&
m
)
{
m
.
def
(
"get_arr_ndim"
,
[](
const
py
::
array
&
arr
)
{
auto
sm
=
m
.
def_submodule
(
"array"
);
return
arr
.
ndim
();
});
sm
.
def
(
"ndim"
,
[](
const
arr
&
a
)
{
return
a
.
ndim
();
});
m
.
def
(
"get_arr_shape"
,
[](
const
py
::
array
&
arr
)
{
sm
.
def
(
"shape"
,
[](
const
arr
&
a
)
{
return
arr
(
a
.
ndim
(),
a
.
shape
());
});
return
std
::
vector
<
size_t
>
(
arr
.
shape
(),
arr
.
shape
()
+
arr
.
ndim
());
sm
.
def
(
"shape"
,
[](
const
arr
&
a
,
size_t
dim
)
{
return
a
.
shape
(
dim
);
});
});
sm
.
def
(
"strides"
,
[](
const
arr
&
a
)
{
return
arr
(
a
.
ndim
(),
a
.
strides
());
});
m
.
def
(
"get_arr_shape"
,
[](
const
py
::
array
&
arr
,
size_t
dim
)
{
sm
.
def
(
"strides"
,
[](
const
arr
&
a
,
size_t
dim
)
{
return
a
.
strides
(
dim
);
});
return
arr
.
shape
(
dim
);
sm
.
def
(
"writeable"
,
[](
const
arr
&
a
)
{
return
a
.
writeable
();
});
});
sm
.
def
(
"size"
,
[](
const
arr
&
a
)
{
return
a
.
size
();
});
m
.
def
(
"get_arr_strides"
,
[](
const
py
::
array
&
arr
)
{
sm
.
def
(
"itemsize"
,
[](
const
arr
&
a
)
{
return
a
.
itemsize
();
});
return
std
::
vector
<
size_t
>
(
arr
.
strides
(),
arr
.
strides
()
+
arr
.
ndim
());
sm
.
def
(
"nbytes"
,
[](
const
arr
&
a
)
{
return
a
.
nbytes
();
});
});
sm
.
def
(
"owndata"
,
[](
const
arr
&
a
)
{
return
a
.
owndata
();
});
m
.
def
(
"get_arr_strides"
,
[](
const
py
::
array
&
arr
,
size_t
dim
)
{
return
arr
.
strides
(
dim
);
def_index_fn
(
data
,
const
arr
&
);
});
def_index_fn
(
data_t
,
const
arr_t
&
);
m
.
def
(
"get_arr_writeable"
,
[](
const
py
::
array
&
arr
)
{
def_index_fn
(
index_at
,
const
arr
&
);
return
arr
.
writeable
();
def_index_fn
(
index_at_t
,
const
arr_t
&
);
});
def_index_fn
(
offset_at
,
const
arr
&
);
m
.
def
(
"get_arr_size"
,
[](
const
py
::
array
&
arr
)
{
def_index_fn
(
offset_at_t
,
const
arr_t
&
);
return
arr
.
size
();
def_index_fn
(
mutate_data
,
arr
&
);
});
def_index_fn
(
mutate_data_t
,
arr_t
&
);
m
.
def
(
"get_arr_itemsize"
,
[](
const
py
::
array
&
arr
)
{
def_index_fn
(
at_t
,
const
arr_t
&
);
return
arr
.
itemsize
();
def_index_fn
(
mutate_at_t
,
arr_t
&
);
});
m
.
def
(
"get_arr_nbytes"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
nbytes
();
});
m
.
def
(
"get_arr_owndata"
,
[](
const
py
::
array
&
arr
)
{
return
arr
.
owndata
();
});
});
});
tests/test_numpy_array.py
View file @
aca6bcae
...
@@ -4,40 +4,147 @@ with pytest.suppress(ImportError):
...
@@ -4,40 +4,147 @@ with pytest.suppress(ImportError):
import
numpy
as
np
import
numpy
as
np
@
pytest
.
fixture
(
scope
=
'function'
)
def
arr
():
return
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
'<u2'
)
@
pytest
.
requires_numpy
@
pytest
.
requires_numpy
def
test_array_attributes
():
def
test_array_attributes
():
from
pybind11_tests
import
(
get_arr_ndim
,
get_arr_shape
,
get_arr_strides
,
get_arr_writeable
,
from
pybind11_tests.array
import
(
get_arr_size
,
get_arr_itemsize
,
get_arr_nbytes
,
get_arr_owndata
)
ndim
,
shape
,
strides
,
writeable
,
size
,
itemsize
,
nbytes
,
owndata
)
a
=
np
.
array
(
0
,
'f8'
)
a
=
np
.
array
(
0
,
'f8'
)
assert
get_arr_ndim
(
a
)
==
0
assert
ndim
(
a
)
==
0
assert
get_arr_shape
(
a
)
==
[]
assert
all
(
shape
(
a
)
==
[])
assert
get_arr_strides
(
a
)
==
[]
assert
all
(
strides
(
a
)
==
[])
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
get_arr_shape
(
a
,
1
)
shape
(
a
,
0
)
with
pytest
.
raises
(
RuntimeError
):
assert
str
(
excinfo
.
value
)
==
'invalid axis: 0 (ndim = 0)'
get_arr_strides
(
a
,
0
)
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
assert
get_arr_writeable
(
a
)
strides
(
a
,
0
)
assert
get_arr_size
(
a
)
==
1
assert
str
(
excinfo
.
value
)
==
'invalid axis: 0 (ndim = 0)'
assert
get_arr_itemsize
(
a
)
==
8
assert
writeable
(
a
)
assert
get_arr_nbytes
(
a
)
==
8
assert
size
(
a
)
==
1
assert
get_arr_owndata
(
a
)
assert
itemsize
(
a
)
==
8
assert
nbytes
(
a
)
==
8
assert
owndata
(
a
)
a
=
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
'u2'
).
view
()
a
=
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
'u2'
).
view
()
a
.
flags
.
writeable
=
False
a
.
flags
.
writeable
=
False
assert
get_arr_ndim
(
a
)
==
2
assert
ndim
(
a
)
==
2
assert
get_arr_shape
(
a
)
==
[
2
,
3
]
assert
all
(
shape
(
a
)
==
[
2
,
3
])
assert
get_arr_shape
(
a
,
0
)
==
2
assert
shape
(
a
,
0
)
==
2
assert
get_arr_shape
(
a
,
1
)
==
3
assert
shape
(
a
,
1
)
==
3
assert
get_arr_strides
(
a
)
==
[
6
,
2
]
assert
all
(
strides
(
a
)
==
[
6
,
2
])
assert
get_arr_strides
(
a
,
0
)
==
6
assert
strides
(
a
,
0
)
==
6
assert
get_arr_strides
(
a
,
1
)
==
2
assert
strides
(
a
,
1
)
==
2
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
get_arr_shape
(
a
,
2
)
shape
(
a
,
2
)
with
pytest
.
raises
(
RuntimeError
):
assert
str
(
excinfo
.
value
)
==
'invalid axis: 2 (ndim = 2)'
get_arr_strides
(
a
,
2
)
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
assert
not
get_arr_writeable
(
a
)
strides
(
a
,
2
)
assert
get_arr_size
(
a
)
==
6
assert
str
(
excinfo
.
value
)
==
'invalid axis: 2 (ndim = 2)'
assert
get_arr_itemsize
(
a
)
==
2
assert
not
writeable
(
a
)
assert
get_arr_nbytes
(
a
)
==
12
assert
size
(
a
)
==
6
assert
not
get_arr_owndata
(
a
)
assert
itemsize
(
a
)
==
2
assert
nbytes
(
a
)
==
12
assert
not
owndata
(
a
)
@
pytest
.
requires_numpy
@
pytest
.
mark
.
parametrize
(
'args, ret'
,
[([],
0
),
([
0
],
0
),
([
1
],
3
),
([
0
,
1
],
1
),
([
1
,
2
],
5
)])
def
test_index_offset
(
arr
,
args
,
ret
):
from
pybind11_tests.array
import
index_at
,
index_at_t
,
offset_at
,
offset_at_t
assert
index_at
(
arr
,
*
args
)
==
ret
assert
index_at_t
(
arr
,
*
args
)
==
ret
assert
offset_at
(
arr
,
*
args
)
==
ret
*
arr
.
dtype
.
itemsize
assert
offset_at_t
(
arr
,
*
args
)
==
ret
*
arr
.
dtype
.
itemsize
@
pytest
.
requires_numpy
def
test_dim_check_fail
(
arr
):
from
pybind11_tests.array
import
(
index_at
,
index_at_t
,
offset_at
,
offset_at_t
,
data
,
data_t
,
mutate_data
,
mutate_data_t
)
for
func
in
(
index_at
,
index_at_t
,
offset_at
,
offset_at_t
,
data
,
data_t
,
mutate_data
,
mutate_data_t
):
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
func
(
arr
,
1
,
2
,
3
)
assert
str
(
excinfo
.
value
)
==
'too many indices for an array: 3 (ndim = 2)'
@
pytest
.
requires_numpy
@
pytest
.
mark
.
parametrize
(
'args, ret'
,
[([],
[
1
,
2
,
3
,
4
,
5
,
6
]),
([
1
],
[
4
,
5
,
6
]),
([
0
,
1
],
[
2
,
3
,
4
,
5
,
6
]),
([
1
,
2
],
[
6
])])
def
test_data
(
arr
,
args
,
ret
):
from
pybind11_tests.array
import
data
,
data_t
assert
all
(
data_t
(
arr
,
*
args
)
==
ret
)
assert
all
(
data
(
arr
,
*
args
)[::
2
]
==
ret
)
assert
all
(
data
(
arr
,
*
args
)[
1
::
2
]
==
0
)
@
pytest
.
requires_numpy
def
test_mutate_readonly
(
arr
):
from
pybind11_tests.array
import
mutate_data
,
mutate_data_t
,
mutate_at_t
arr
.
flags
.
writeable
=
False
for
func
,
args
in
(
mutate_data
,
()),
(
mutate_data_t
,
()),
(
mutate_at_t
,
(
0
,
0
)):
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
func
(
arr
,
*
args
)
assert
str
(
excinfo
.
value
)
==
'array is not writeable'
@
pytest
.
requires_numpy
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
0
,
1
,
3
])
def
test_at_fail
(
arr
,
dim
):
from
pybind11_tests.array
import
at_t
,
mutate_at_t
for
func
in
at_t
,
mutate_at_t
:
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
func
(
arr
,
*
([
0
]
*
dim
))
assert
str
(
excinfo
.
value
)
==
'index dimension mismatch: {} (ndim = 2)'
.
format
(
dim
)
@
pytest
.
requires_numpy
def
test_at
(
arr
):
from
pybind11_tests.array
import
at_t
,
mutate_at_t
assert
at_t
(
arr
,
0
,
2
)
==
3
assert
at_t
(
arr
,
1
,
0
)
==
4
assert
all
(
mutate_at_t
(
arr
,
0
,
2
).
ravel
()
==
[
1
,
2
,
4
,
4
,
5
,
6
])
assert
all
(
mutate_at_t
(
arr
,
1
,
0
).
ravel
()
==
[
1
,
2
,
4
,
5
,
5
,
6
])
@
pytest
.
requires_numpy
def
test_mutate_data
(
arr
):
from
pybind11_tests.array
import
mutate_data
,
mutate_data_t
assert
all
(
mutate_data
(
arr
).
ravel
()
==
[
2
,
4
,
6
,
8
,
10
,
12
])
assert
all
(
mutate_data
(
arr
).
ravel
()
==
[
4
,
8
,
12
,
16
,
20
,
24
])
assert
all
(
mutate_data
(
arr
,
1
).
ravel
()
==
[
4
,
8
,
12
,
32
,
40
,
48
])
assert
all
(
mutate_data
(
arr
,
0
,
1
).
ravel
()
==
[
4
,
16
,
24
,
64
,
80
,
96
])
assert
all
(
mutate_data
(
arr
,
1
,
2
).
ravel
()
==
[
4
,
16
,
24
,
64
,
80
,
192
])
assert
all
(
mutate_data_t
(
arr
).
ravel
()
==
[
5
,
17
,
25
,
65
,
81
,
193
])
assert
all
(
mutate_data_t
(
arr
).
ravel
()
==
[
6
,
18
,
26
,
66
,
82
,
194
])
assert
all
(
mutate_data_t
(
arr
,
1
).
ravel
()
==
[
6
,
18
,
26
,
67
,
83
,
195
])
assert
all
(
mutate_data_t
(
arr
,
0
,
1
).
ravel
()
==
[
6
,
19
,
27
,
68
,
84
,
196
])
assert
all
(
mutate_data_t
(
arr
,
1
,
2
).
ravel
()
==
[
6
,
19
,
27
,
68
,
84
,
197
])
@
pytest
.
requires_numpy
def
test_bounds_check
(
arr
):
from
pybind11_tests.array
import
(
index_at
,
index_at_t
,
data
,
data_t
,
mutate_data
,
mutate_data_t
,
at_t
,
mutate_at_t
)
funcs
=
(
index_at
,
index_at_t
,
data
,
data_t
,
mutate_data
,
mutate_data_t
,
at_t
,
mutate_at_t
)
for
func
in
funcs
:
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
index_at
(
arr
,
2
,
0
)
assert
str
(
excinfo
.
value
)
==
'index 2 is out of bounds for axis 0 with size 2'
with
pytest
.
raises
(
IndexError
)
as
excinfo
:
index_at
(
arr
,
0
,
4
)
assert
str
(
excinfo
.
value
)
==
'index 4 is out of bounds for axis 1 with size 3'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment