"...git@developer.sourcefind.cn:yaoyuping/nndetection.git" did not exist on "7246044d8824f7b3f6c243db054b61420212ad05"
test_shape_base.py 3.24 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import unittest

import numpy
import pytest

import cupy
from cupy import testing


def _f1to2(x):
    """Return an asymmetric, non-square matrix built from the 1-d array ``x``.

    ``x[::-1]`` has shape ``(n,)`` and ``x[1:, None]`` has shape
    ``(n - 1, 1)``, so the broadcast product has shape ``(n - 1, n)``.
    Applying this along an axis therefore replaces one axis with two.
    """
    assert x.ndim == 1
    return x[::-1] * x[1:, None]


@testing.parameterize(*testing.product({'axis': [0, 1, -1]}))
class TestApplyAlongAxis(unittest.TestCase):
    """Compare ``cupy.apply_along_axis`` with NumPy for each ``self.axis``."""

    @testing.numpy_cupy_array_equal()
    def test_simple(self, xp):
        a = xp.ones((20, 10), 'd')
        return xp.apply_along_axis(len, self.axis, a)

    @testing.for_all_dtypes(no_bool=True)
    @testing.numpy_cupy_array_equal()
    def test_3d(self, xp, dtype):
        a = xp.arange(27, dtype=dtype).reshape((3, 3, 3))
        return xp.apply_along_axis(xp.sum, self.axis, a)

    @testing.numpy_cupy_array_equal()
    def test_0d_array(self, xp):

        def sum_to_0d(x):
            """Sum ``x``, returning a 0-d array of the same class."""
            assert x.ndim == 1
            return xp.squeeze(xp.sum(x, keepdims=True))

        a = xp.ones((6, 3))
        return xp.apply_along_axis(sum_to_0d, self.axis, a)

    @testing.numpy_cupy_array_equal()
    def test_axis_insertion_2d(self, xp):
        # 2-d input: the mapped axis is replaced by the two axes of _f1to2.
        a2d = xp.arange(6 * 3).reshape((6, 3))
        return xp.apply_along_axis(_f1to2, self.axis, a2d)

    @testing.numpy_cupy_array_equal()
    def test_axis_insertion_3d(self, xp):
        # 3-d input: same insertion, with one more iteration dimension.
        a3d = xp.arange(6 * 5 * 3).reshape((6, 5, 3))
        return xp.apply_along_axis(_f1to2, self.axis, a3d)

    def test_empty1(self):
        # With every dimension empty there is no slice to call func1d on,
        # so both libraries must refuse with ValueError.
        def never_call(x):
            # Raise explicitly: a bare ``assert False`` is stripped under -O.
            raise AssertionError('never_call should not be reached')

        for xp in (numpy, cupy):
            a = xp.empty((0, 0))
            with pytest.raises(ValueError):
                xp.apply_along_axis(never_call, self.axis, a)

    def test_empty2(self):
        # A zero-length *mapped* axis is fine as long as the iteration
        # dimensions are non-empty: func1d is then called on empty slices.
        def empty_to_1(x):
            assert len(x) == 0
            return 1

        for xp in (numpy, cupy):
            shape = [10, 10]
            shape[self.axis] = 0
            a = xp.empty(tuple(shape))
            other_axis = 1 if self.axis == 0 else 0

            # Mapping along the non-empty axis iterates over the
            # zero-length one, so there is nothing to call func1d on.
            with pytest.raises(ValueError):
                xp.apply_along_axis(empty_to_1, other_axis, a)

            # Okay to call along the zero-length axis itself.
            testing.assert_array_equal(
                xp.apply_along_axis(empty_to_1, self.axis, a),
                xp.ones((10,))
            )

    @testing.numpy_cupy_array_equal()
    def test_tuple_outs(self, xp):
        # func1d returns a tuple of three 0-d results, which is converted
        # to a length-3 axis placed where the mapped axis was.
        def func(x):
            return x.sum(axis=-1), x.prod(axis=-1), x.max(axis=-1)

        a = testing.shaped_arange((2, 2, 2), xp, cupy.int64)
        # Use the parameterized axis so each parameterization exercises
        # a different case instead of repeating axis=1 three times.
        return xp.apply_along_axis(func, self.axis, a)


@testing.with_requires('numpy>=1.16')
def test_apply_along_axis_invalid_axis():
    """Out-of-range axes must raise ``AxisError`` in both NumPy and CuPy."""
    # ``AxisError`` lives in ``numpy.exceptions`` since NumPy 1.25; older
    # releases only expose the top-level ``numpy.AxisError`` alias.
    axis_error = getattr(numpy, 'exceptions', numpy).AxisError
    for xp in (numpy, cupy):
        a = xp.ones((8, 4))
        # For a 2-d array valid axes are -2..1; probe just outside each end.
        for axis in (-3, 2):
            with pytest.raises(axis_error):
                xp.apply_along_axis(xp.sum, axis, a)