test_piecewise.py 4.32 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import unittest

import pytest

import numpy
import cupy
from cupy import testing


class TestPiecewise(unittest.TestCase):
    """Compare ``cupy.piecewise`` with ``numpy.piecewise``.

    Tests decorated with ``numpy_cupy_array_equal`` build the same input
    with whichever array module is passed in as ``xp`` and return the
    ``piecewise`` result, so the NumPy and CuPy outputs can be compared.
    The remaining tests check the exceptions raised for invalid or
    unsupported ``funclist`` arguments.
    """

    @staticmethod
    def _piecewise_by_sign(xp, arr):
        # Two conditions and three values: elements matching neither
        # condition receive the trailing value (2).
        return xp.piecewise(arr, [arr < 0, arr > 0], [10, 1, 2])

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise(self, xp, dtype):
        """Overlapping conditions with exactly one value per condition."""
        values = xp.linspace(2.5, 12.5, 6, dtype=dtype)
        choices = xp.array([-1, 1, 2, 5])
        conditions = [values < 0, values >= 0, values < 5, values >= 1.5]
        return xp.piecewise(values, conditions, choices)

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_scalar_input(self, xp, dtype):
        """A NumPy scalar, rather than an array, as the input."""
        scalar = dtype(2)
        return xp.piecewise(scalar, [scalar < 0, scalar >= 0], [1, 10])

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_scalar_condition(self, xp, dtype):
        """A bare boolean as ``condlist`` instead of a list of masks."""
        arr = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype)
        # One condition and two values: the second value is the default.
        return xp.piecewise(arr, True, xp.array([-10, 10]))

    @testing.for_signed_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_otherwise_condition1(self, xp, dtype):
        """A trailing default value (-5) for signed input."""
        grid = xp.linspace(-2, 20, 12, dtype=dtype)
        conditions = [grid > 15, grid <= 5, grid == 0, grid == 10]
        # Five values for four conditions, so -5 is the "otherwise" value.
        return xp.piecewise(grid, conditions, xp.array([-1, 0, 2, 3, -5]))

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_otherwise_condition2(self, xp, dtype):
        """Two identical explicit masks plus a trailing default value."""
        data = xp.array([-10, 20, 30, 40]).astype(dtype)
        pattern = [True, False, False, True]
        # Two independent mask arrays built from the same pattern.
        masks = [xp.array(pattern) for _ in range(2)]
        return xp.piecewise(data, masks, xp.array([-1, 1, 2]))

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_zero_dim_input(self, xp, dtype):
        """A 0-d array input split by the sign of its value."""
        scalar_arr = testing.shaped_random(shape=(), xp=xp, dtype=dtype)
        return self._piecewise_by_sign(xp, scalar_arr)

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_ndim_input(self, xp, dtype):
        """A 3-d array input split by the sign of each element."""
        cube = testing.shaped_random(shape=(2, 3, 5), xp=xp, dtype=dtype)
        return self._piecewise_by_sign(xp, cube)

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_zero_dim_condlist(self, xp, dtype):
        """A single 0-d boolean condition applied to a 0-d input."""
        scalar_arr = testing.shaped_random(shape=(), xp=xp, dtype=dtype)
        flag = testing.shaped_random(shape=(), xp=xp, dtype=bool)
        return xp.piecewise(scalar_arr, [flag], [1, 2])

    @testing.for_all_dtypes()
    @testing.numpy_cupy_array_equal()
    def test_piecewise_ndarray_condlist_funclist(self, xp, dtype):
        """Conditions and values both given as arrays, not lists."""
        grid = xp.linspace(1, 20, 12, dtype=dtype)
        stacked = xp.array([grid > 15, grid <= 5, grid == 0, grid == 10])
        # Values are cast to the input dtype before being passed in.
        fills = xp.array([-1, 0, 2, 3, -5]).astype(dtype)
        return xp.piecewise(grid, stacked, fills)

    @testing.for_all_dtypes_combination(
        names=['dtype1', 'dtype2'], no_complex=True)
    @testing.numpy_cupy_array_equal()
    def test_piecewise_diff_types_funclist(self, xp, dtype1, dtype2):
        """Value array whose dtype differs from the input dtype."""
        grid = xp.linspace(1, 20, 12, dtype=dtype1)
        fills = xp.array([1, 0, 2, 3, 5], dtype=dtype2)
        conditions = [grid > 15, grid <= 5, grid == 0, grid == 10]
        return xp.piecewise(grid, conditions, fills)

    @testing.for_all_dtypes()
    def test_mismatched_lengths(self, dtype):
        """NumPy and CuPy both reject five values for two conditions."""
        too_many = [-1, 0, 2, 4, 5]
        for module in (numpy, cupy):
            grid = module.linspace(-2, 4, 6, dtype=dtype)
            conditions = [grid < 0, grid >= 0]
            with pytest.raises(ValueError):
                module.piecewise(grid, conditions, too_many)

    @testing.for_all_dtypes()
    def test_callable_funclist(self, dtype):
        """CuPy raises NotImplementedError for callable entries."""
        grid = cupy.linspace(-2, 4, 6, dtype=dtype)
        conditions = [grid < 0, grid > 0]
        funcs = [lambda v: -v, lambda v: v]
        with pytest.raises(NotImplementedError):
            cupy.piecewise(grid, conditions, funcs)

    @testing.for_all_dtypes()
    def test_mixed_funclist(self, dtype):
        """CuPy raises NotImplementedError when constants and callables mix."""
        grid = cupy.linspace(-2, 2, 6, dtype=dtype)
        conditions = [grid < 0, grid == 0, grid > 0]
        # Four entries for three conditions; the trailing default is callable.
        mixed = [-10, lambda v: -v, 10, lambda v: v]
        with pytest.raises(NotImplementedError):
            cupy.piecewise(grid, conditions, mixed)