# -*- coding: utf-8 -*-
"""Tests for pybind11's numpy array bindings (``pybind11_tests.numpy_array``)."""
import pytest

import env  # noqa: F401

from pybind11_tests import numpy_array as m

# Every test below needs numpy; skip the whole module when it is unavailable.
np = pytest.importorskip("numpy")


def test_dtypes():
    """Compare C++ and numpy dtype sizes and descriptors (see issue #1328)."""
    # Platform-dependent sizes must agree between C++ and numpy.
    for platform_check in m.get_platform_dtype_size_checks():
        print(platform_check)
        assert platform_check.size_cpp == platform_check.size_numpy, platform_check

    # Concrete dtypes must compare equal; a differing typenum is only reported.
    for dtype_check in m.get_concrete_dtype_checks():
        print(dtype_check)
        numpy_dtype, pybind_dtype = dtype_check.numpy, dtype_check.pybind11
        assert numpy_dtype == pybind_dtype, dtype_check
        if numpy_dtype.num != pybind_dtype.num:
            print("NOTE: typenum mismatch for {}: {} != {}".format(
                dtype_check, numpy_dtype.num, pybind_dtype.num))


@pytest.fixture(scope='function')
def arr():
    """Return a fresh 2x3 native-endian uint16 array for each test."""
    rows = [[1, 2, 3], [4, 5, 6]]
    return np.array(rows, '=u2')


def test_array_attributes():
    """Check the shape/stride/size/flag accessors on a 0-d and a 2-d array."""
    # 0-d float64 array that owns its (writeable) data.
    scalar = np.array(0, 'f8')
    assert m.ndim(scalar) == 0
    assert all(m.shape(scalar) == [])
    assert all(m.strides(scalar) == [])
    for accessor in (m.shape, m.strides):
        with pytest.raises(IndexError) as excinfo:
            accessor(scalar, 0)
        assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    assert m.writeable(scalar)
    assert m.size(scalar) == 1
    assert m.itemsize(scalar) == 8
    assert m.nbytes(scalar) == 8
    assert m.owndata(scalar)

    # Read-only view of a 2x3 uint16 array: it neither owns nor may write its data.
    matrix = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
    matrix.flags.writeable = False
    assert m.ndim(matrix) == 2
    assert all(m.shape(matrix) == [2, 3])
    assert (m.shape(matrix, 0), m.shape(matrix, 1)) == (2, 3)
    assert all(m.strides(matrix) == [6, 2])
    assert (m.strides(matrix, 0), m.strides(matrix, 1)) == (6, 2)
    for accessor in (m.shape, m.strides):
        with pytest.raises(IndexError) as excinfo:
            accessor(matrix, 2)
        assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    assert not m.writeable(matrix)
    assert m.size(matrix) == 6
    assert m.itemsize(matrix) == 2
    assert m.nbytes(matrix) == 12
    assert not m.owndata(matrix)


@pytest.mark.parametrize('args, ret', [
    ([], 0),
    ([0], 0),
    ([1], 3),
    ([0, 1], 1),
    ([1, 2], 5),
])
def test_index_offset(arr, args, ret):
    """Flat element index and byte offset for a (possibly partial) index."""
    byte_offset = ret * arr.dtype.itemsize
    for index_func in (m.index_at, m.index_at_t):
        assert index_func(arr, *args) == ret
    for offset_func in (m.offset_at, m.offset_at_t):
        assert offset_func(arr, *args) == byte_offset


def test_dim_check_fail(arr):
    """Passing more indices than ``arr`` has dimensions must raise IndexError."""
    checked_funcs = [
        m.index_at, m.index_at_t,
        m.offset_at, m.offset_at_t,
        m.data, m.data_t,
        m.mutate_data, m.mutate_data_t,
    ]
    too_many = (1, 2, 3)
    for func in checked_funcs:
        with pytest.raises(IndexError) as excinfo:
            func(arr, *too_many)
        assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'


@pytest.mark.parametrize('args, ret', [
    ([], [1, 2, 3, 4, 5, 6]),
    ([1], [4, 5, 6]),
    ([0, 1], [2, 3, 4, 5, 6]),
    ([1, 2], [6]),
])
def test_data(arr, args, ret):
    """Typed and raw-byte views of ``arr``'s data starting at ``args``."""
    import sys

    assert all(m.data_t(arr, *args) == ret)
    # ``arr`` holds uint16 values below 256, so each element's high byte is zero;
    # the low byte comes first on little-endian hosts and second otherwise.
    low, high = (0, 1) if sys.byteorder == 'little' else (1, 0)
    assert all(m.data(arr, *args)[low::2] == ret)
    assert all(m.data(arr, *args)[high::2] == 0)


@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """Element accessors need exactly ``ndim`` (= 2) indices; any other count fails."""
    zero_indices = [0] * dim
    expected = 'index dimension mismatch: {} (ndim = 2)'.format(dim)
    for accessor in (m.at_t, m.mutate_at_t):
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *zero_indices)
        assert str(excinfo.value) == expected


def test_at(arr):
    """Read single elements, then increment them in place; the effects accumulate."""
    for (row, col), value in (((0, 2), 3), ((1, 0), 4)):
        assert m.at_t(arr, row, col) == value

    # Each mutate_at_t call bumps one element of the same array by one.
    increments = [
        ((0, 2), [1, 2, 4, 4, 5, 6]),
        ((1, 0), [1, 2, 4, 5, 5, 6]),
    ]
    for (row, col), flattened in increments:
        assert all(m.mutate_at_t(arr, row, col).ravel() == flattened)


def test_mutate_readonly(arr):
    """Every mutating accessor must refuse a read-only array."""
    arr.flags.writeable = False
    attempts = [
        (m.mutate_data, ()),
        (m.mutate_data_t, ()),
        (m.mutate_at_t, (0, 0)),
    ]
    for mutator, indices in attempts:
        with pytest.raises(ValueError) as excinfo:
            mutator(arr, *indices)
        assert str(excinfo.value) == 'array is not writeable'


def test_mutate_data(arr):
    """Mutate ``arr`` from a start index onward; every step builds on the previous one."""
    # (mutator, start index, expected flattened contents after the call)
    steps = [
        # mutate_data doubles every element from the start index onward.
        (m.mutate_data, (), [2, 4, 6, 8, 10, 12]),
        (m.mutate_data, (), [4, 8, 12, 16, 20, 24]),
        (m.mutate_data, (1,), [4, 8, 12, 32, 40, 48]),
        (m.mutate_data, (0, 1), [4, 16, 24, 64, 80, 96]),
        (m.mutate_data, (1, 2), [4, 16, 24, 64, 80, 192]),
        # mutate_data_t adds one to every element from the start index onward.
        (m.mutate_data_t, (), [5, 17, 25, 65, 81, 193]),
        (m.mutate_data_t, (), [6, 18, 26, 66, 82, 194]),
        (m.mutate_data_t, (1,), [6, 18, 26, 67, 83, 195]),
        (m.mutate_data_t, (0, 1), [6, 19, 27, 68, 84, 196]),
        (m.mutate_data_t, (1, 2), [6, 19, 27, 68, 84, 197]),
    ]
    for mutator, start, expected in steps:
        assert all(mutator(arr, *start).ravel() == expected)


def test_bounds_check(arr):
    """An out-of-range index on either axis must raise a descriptive IndexError."""
    accessors = (
        m.index_at, m.index_at_t, m.data, m.data_t,
        m.mutate_data, m.mutate_data_t, m.at_t, m.mutate_at_t,
    )
    bad_indices = (
        ((2, 0), 'index 2 is out of bounds for axis 0 with size 2'),
        ((0, 4), 'index 4 is out of bounds for axis 1 with size 3'),
    )
    for accessor in accessors:
        for indices, message in bad_indices:
            with pytest.raises(IndexError) as excinfo:
                accessor(arr, *indices)
            assert str(excinfo.value) == message

def test_make_c_f_array():
    """Each factory yields an array that is contiguous only in its own memory order."""
    layouts = (
        (m.make_c_array, 'c_contiguous', 'f_contiguous'),
        (m.make_f_array, 'f_contiguous', 'c_contiguous'),
    )
    for make, own_order, other_order in layouts:
        assert getattr(make().flags, own_order)
        assert not getattr(make().flags, other_order)


def test_make_empty_shaped_array():
    """An empty shape is accepted, and ``scalar_int`` returns a 0-d array equal to 42."""
    # Only needs to complete without raising.
    m.make_empty_shaped_array()

    # Empty shape means numpy scalar (PEP 3118); each check builds a fresh result.
    scalar_checks = (
        lambda s: s.ndim == 0,
        lambda s: s.shape == (),
        lambda s: s == 42,
    )
    for check in scalar_checks:
        assert check(m.scalar_int())

def test_wrap():
    """``m.wrap`` must return a new array object that aliases the input's buffer."""
    def assert_references(a, b, base=None):
        """Assert ``b`` is a distinct array viewing exactly the memory of ``a``.

        ``base`` is the object expected as ``b.base``; it defaults to ``a``.
        """
        if base is None:
            base = a
        assert a is not b
        # Same data pointer, layout and flags: b is a view onto a's buffer.
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # numpy 1.14 introduced ``writebackifcopy`` (superseding ``updateifcopy``).
        # Compare versions with numpy's own helper: distutils' LooseVersion is
        # gone from the standard library since Python 3.12.
        if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
            assert a.flags.writebackifcopy == b.flags.writebackifcopy
        else:
            assert a.flags.updateifcopy == b.flags.updateifcopy
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is base
        # A write through ``a`` must be visible through ``b`` (writeable 2-D cases).
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    # 1-D array that owns its data.
    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # Fortran-ordered 2-D array.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # Read-only C-ordered 2-D array.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    # 3-D array, then several views of it whose base is the original array.
    a1 = np.random.random((4, 4, 4))
    a2 = m.wrap(a1)
    assert_references(a1, a2)

    a1t = a1.transpose()
    a2 = m.wrap(a1t)
    assert_references(a1t, a2, a1)

    a1d = a1.diagonal()
    a2 = m.wrap(a1d)
    assert_references(a1d, a2, a1)

    a1m = a1[::-1, ::-1, ::-1]
    a2 = m.wrap(a1m)
    assert_references(a1m, a2, a1)

def test_numpy_view(capture):
    """Views from ``numpy_view`` share memory and keep their C++ owner alive."""
    with capture:
        owner = m.ArrayClass()
        first_view = owner.numpy_view()
        second_view = owner.numpy_view()
        assert np.all(first_view == np.array([1, 2], dtype=np.int32))
        # Dropping our handle must not destroy the object while views exist.
        del owner
        pytest.gc_collect()
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """

    # A write through one view is observed through the other.
    first_view[0] = 4
    first_view[1] = 3
    assert (second_view[0], second_view[1]) == (4, 3)

    # The destructor is expected only once the last view is gone.
    with capture:
        del first_view, second_view
        pytest.gc_collect()
        pytest.gc_collect()
    assert capture == """
        ~ArrayClass()
    """


def test_cast_numpy_int64_to_uint64():
    """A uint64 parameter accepts both a Python int and a numpy uint64."""
    for value in (123, np.uint64(123)):
        m.function_taking_uint64(value)


def test_isinstance():
    """Exercise the untyped and typed ``isinstance`` helpers with numpy arrays."""
    int_values = np.array([1, 2, 3])
    assert m.isinstance_untyped(int_values, "not an array")

    float_values = np.array([1.0, 2.0, 3.0])
    assert m.isinstance_typed(float_values)


def test_constructors():
    """Default constructors give empty arrays; converting ones copy ``[1, 2, 3]``."""
    defaults = m.default_constructors()
    for empty in defaults.values():
        assert empty.size == 0
    expected_default_dtypes = {
        "array": np.array([]).dtype,
        "array_t<int32>": np.int32,
        "array_t<double>": np.float64,
    }
    for key, dtype in expected_default_dtypes.items():
        assert defaults[key].dtype == dtype

    results = m.converting_constructors([1, 2, 3])
    for converted in results.values():
        np.testing.assert_array_equal(converted, [1, 2, 3])
    expected_converted_dtypes = {
        "array": np.int_,
        "array_t<int32>": np.int32,
        "array_t<double>": np.float64,
    }
    for key, dtype in expected_converted_dtypes.items():
        assert results[key].dtype == dtype


def test_overload_resolution(msg):
    """Check which C++ overload is selected for numpy arrays of various dtypes."""
    # Exact overload matches:
    for dtype, expected in [
        ('float64', 'double'),
        ('float32', 'float'),
        ('ushort', 'unsigned short'),
        ('intc', 'int'),
        ('longlong', 'long long'),
        ('complex', 'double complex'),
        ('csingle', 'float complex'),
    ]:
        assert m.overloaded(np.array([1], dtype=dtype)) == expected, dtype

    # No exact match, should call first convertible version:
    assert m.overloaded(np.array([1], dtype='uint8')) == 'double'

    with pytest.raises(TypeError) as excinfo:
        m.overloaded("not an array")
    assert msg(excinfo.value) == """
        overloaded(): incompatible function arguments. The following argument types are supported:
            1. (arg0: numpy.ndarray[numpy.float64]) -> str
            2. (arg0: numpy.ndarray[numpy.float32]) -> str
            3. (arg0: numpy.ndarray[numpy.int32]) -> str
            4. (arg0: numpy.ndarray[numpy.uint16]) -> str
            5. (arg0: numpy.ndarray[numpy.int64]) -> str
            6. (arg0: numpy.ndarray[numpy.complex128]) -> str
            7. (arg0: numpy.ndarray[numpy.complex64]) -> str

        Invoked with: 'not an array'
    """

    # overloaded2 distinguishes real and complex dtypes of both precisions.
    # (The float32 case was previously asserted twice; it is checked once here.)
    for dtype, expected in [
        ('float64', 'double'),
        ('float32', 'float'),
        ('complex64', 'float complex'),
        ('complex128', 'double complex'),
    ]:
        assert m.overloaded2(np.array([1], dtype=dtype)) == expected, dtype

    # overloaded3 only has int32 and float64 overloads (see the message below).
    assert m.overloaded3(np.array([1], dtype='float64')) == 'double'
    assert m.overloaded3(np.array([1], dtype='intc')) == 'int'
    expected_exc = """
        overloaded3(): incompatible function arguments. The following argument types are supported:
            1. (arg0: numpy.ndarray[numpy.int32]) -> str
            2. (arg0: numpy.ndarray[numpy.float64]) -> str

        Invoked with: """

    # These dtypes match neither overload exactly and are rejected.
    for dtype, invoked_with in [
        ('uintc', np.array([1], dtype='uint32')),
        ('float32', np.array([1.], dtype='float32')),
        ('complex', np.array([1. + 0.j])),
    ]:
        with pytest.raises(TypeError) as excinfo:
            m.overloaded3(np.array([1], dtype=dtype))
        assert msg(excinfo.value) == expected_exc + repr(invoked_with)

    # Exact matches:
    assert m.overloaded4(np.array([1], dtype='double')) == 'double'
    assert m.overloaded4(np.array([1], dtype='longlong')) == 'long long'
    # Non-exact matches requiring conversion.  Since float to integer isn't a
    # safe conversion, it should go to the double overload, but short can go to
    # either (and so should end up on the first-registered, the long long).
    assert m.overloaded4(np.array([1], dtype='float32')) == 'double'
    assert m.overloaded4(np.array([1], dtype='short')) == 'long long'

    assert m.overloaded5(np.array([1], dtype='double')) == 'double'
    assert m.overloaded5(np.array([1], dtype='uintc')) == 'unsigned int'
    assert m.overloaded5(np.array([1], dtype='float32')) == 'unsigned int'


def test_greedy_string_overload():
    """Regression test for #685: an ndarray must not bind to the std::string overload."""
    cases = [
        ("abc", "string"),
        (np.array([97, 98, 99], dtype='b'), "array"),
        (123, "other"),
    ]
    for argument, overload in cases:
        assert m.issue685(argument) == overload


def test_array_unchecked_fixed_dims(msg):
    """Unchecked proxies whose number of dimensions is fixed."""
    grid = np.array([[1, 2], [3, 4]], dtype='float64')
    # proxy_add2 adds the scalar to every element in place.
    m.proxy_add2(grid, 10)
    assert np.all(grid == [[11, 12], [13, 14]])

    # A 1-D input is rejected: the proxy requires exactly two dimensions.
    with pytest.raises(ValueError) as excinfo:
        m.proxy_add2(np.array([1., 2, 3]), 5.0)
    assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2"

    c_order = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
    assert np.all(m.proxy_init3(3.0) == c_order)
    # The Fortran-order variant yields the transposed result.
    assert np.all(m.proxy_init3F(3.0) == np.transpose(c_order))

    # Sum of squares of 0..5 is 55, for integer and float64 input alike.
    for sample in (np.array(range(6)), np.array(range(6), dtype="float64")):
        assert m.proxy_squared_L2_norm(sample) == 55

    assert m.proxy_auxiliaries2(grid) == [11, 11, True, 2, 8, 2, 2, 4, 32]
    assert m.proxy_auxiliaries2(grid) == m.array_auxiliaries2(grid)

    assert m.proxy_auxiliaries1_const_ref(grid[0, :])
    assert m.proxy_auxiliaries2_const_ref(grid)

def test_array_unchecked_dyn_dims(msg):
    """Unchecked proxies whose number of dimensions is only known at runtime."""
    grid = np.array([[1, 2], [3, 4]], dtype='float64')
    # proxy_add2_dyn adds the scalar to every element in place.
    m.proxy_add2_dyn(grid, 10)
    assert np.all(grid == [[11, 12], [13, 14]])

    c_order = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
    assert np.all(m.proxy_init3_dyn(3.0) == c_order)

    assert m.proxy_auxiliaries2_dyn(grid) == [11, 11, True, 2, 8, 2, 2, 4, 32]
    assert m.proxy_auxiliaries2_dyn(grid) == m.array_auxiliaries2(grid)


def test_array_failure():
    """Invalid array construction raises ValueError with a precise message."""
    failures = [
        (m.array_fail_test, 'cannot create a pybind11::array from a nullptr'),
        (m.array_t_fail_test, 'cannot create a pybind11::array_t from a nullptr'),
        (m.array_fail_test_negative_size, 'negative dimensions are not allowed'),
    ]
    for factory, message in failures:
        with pytest.raises(ValueError) as excinfo:
            factory()
        assert str(excinfo.value) == message

def test_initializer_list():
    """``array_initializer_listN`` returns an array of shape ``(1, ..., N)``."""
    builders = [
        m.array_initializer_list1, m.array_initializer_list2,
        m.array_initializer_list3, m.array_initializer_list4,
    ]
    for rank, build in enumerate(builders, start=1):
        assert build().shape == tuple(range(1, rank + 1))

def test_array_resize(msg):
    """In-place reshape/resize through pybind11, including resizes numpy must refuse."""
    a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64')
    m.array_reshape2(a)
    assert a.size == 9
    assert np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    # Total size change should succeed with refcheck off...
    m.array_resize3(a, 4, False)
    assert a.size == 64
    # ...and fail with refcheck on.  This must raise: the reshape of the
    # transposed view below only works if ``a`` still has 64 elements.
    with pytest.raises(ValueError) as excinfo:
        m.array_resize3(a, 3, True)
    assert str(excinfo.value).startswith("cannot resize an array")

    # A transposed array is a view that does not own its data, so a
    # size-changing resize must be refused even with refcheck off.
    b = a.transpose()
    with pytest.raises(ValueError) as excinfo:
        m.array_resize3(b, 3, False)
    assert str(excinfo.value).startswith("cannot resize this array: it does not own its data")

    # ...but a reshape that keeps the total size is fine.
    m.array_reshape2(b)
    assert b.shape == (8, 8)


@pytest.mark.xfail("env.PYPY")
def test_array_create_and_resize(msg):
    """``create_and_resize(2)`` returns four elements, all equal to 42."""
    created = m.create_and_resize(2)
    assert created.size == 4
    assert np.all(created == 42.)


def test_index_using_ellipsis():
    """Ellipsis indexing on a 5x6x7 input yields a 1-D array of length 6."""
    source = np.zeros((5, 6, 7))
    result = m.index_using_ellipsis(source)
    assert result.shape == (6,)


@pytest.mark.parametrize("forcecast", [False, True])
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
@pytest.mark.parametrize("noconvert", [False, True])
@pytest.mark.filterwarnings(
    # Filter by message only.  ``numpy.ComplexWarning`` was removed from the main
    # numpy namespace in NumPy 2.0 (it lives in ``numpy.exceptions``), and pytest
    # errors out when a filter names a warning category it cannot resolve.
    "ignore:Casting complex values to real discards the imaginary part"
)
def test_argument_conversions(forcecast, contiguity, noconvert):
    """Check which dtype/layout combinations each ``accept_double*`` variant accepts.

    The variant is picked by name: ``_c_style``/``_f_style`` demand that memory
    order, ``_forcecast`` also accepts complex input (truncated), and
    ``_noconvert`` accepts only float64 arrays needing no conversion.
    """
    function_name = "accept_double"
    if contiguity == 'C':
        function_name += "_c_style"
    elif contiguity == 'F':
        function_name += "_f_style"
    if forcecast:
        function_name += "_forcecast"
    if noconvert:
        function_name += "_noconvert"
    function = getattr(m, function_name)

    for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
        for order in ['C', 'F']:
            for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
                if not noconvert:
                    # If noconvert is not passed, only complex128 needs to be truncated and
                    # "cannot be safely obtained". So without `forcecast`, the argument shouldn't
                    # be accepted.
                    should_raise = dtype.name == 'complex128' and not forcecast
                else:
                    # If noconvert is passed, only float64 and the matching order is accepted.
                    # If at most one dimension has a size greater than 1, the array is also
                    # trivially contiguous.
                    trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
                    should_raise = (
                        dtype.name != 'float64' or
                        (contiguity is not None and
                         contiguity != order and
                         not trivially_contiguous)
                    )

                array = np.zeros(shape, dtype=dtype, order=order)
                if not should_raise:
                    function(array)
                else:
                    with pytest.raises(TypeError, match="incompatible function arguments"):
                        function(array)


@pytest.mark.xfail("env.PYPY")
def test_dtype_refcount_leak():
    """Calling ``m.ndim`` must leave the array's dtype reference count unchanged."""
    from sys import getrefcount
    # np.float64 rather than np.float_: the alias was removed in NumPy 2.0 and
    # always referred to float64.
    dtype = np.dtype(np.float64)
    a = np.array([1], dtype=dtype)
    before = getrefcount(dtype)
    m.ndim(a)
    after = getrefcount(dtype)
    assert after == before