test_overlap.py 3.75 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np

from numba import jit
from numba.core import types
from numba.tests.support import TestCase, tag
import unittest


# Array overlaps involving a displacement

def array_overlap1(src, dest, k=1):
    """Shift along the first axis: store ``src[:-k]`` into ``dest[k:]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    leading = src[:-k]
    dest[k:] = leading

def array_overlap2(src, dest, k=1):
    """Shift along the first axis: store ``src[k:]`` into ``dest[:-k]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    trailing = src[k:]
    dest[:-k] = trailing

def array_overlap3(src, dest, k=1):
    """Shift along the second axis: store ``src[:, k:]`` into ``dest[:, :-k]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    trailing_cols = src[:,k:]
    dest[:,:-k] = trailing_cols

def array_overlap4(src, dest, k=1):
    """Shift along the second axis: store ``src[:, :-k]`` into ``dest[:, k:]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    leading_cols = src[:,:-k]
    dest[:,k:] = leading_cols

def array_overlap5(src, dest, k=1):
    """Shift along the last axis: store ``src[..., k:]`` into ``dest[..., :-k]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    trailing = src[...,k:]
    dest[...,:-k] = trailing

def array_overlap6(src, dest, k=1):
    """Shift along the last axis: store ``src[..., :-k]`` into ``dest[..., k:]``.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    leading = src[...,:-k]
    dest[...,k:] = leading

# Array overlaps involving an in-place reversal

def array_overlap11(src, dest):
    """Store *src* into *dest* through a view reversed along the first axis.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert src.shape == dest.shape
    # The reversal is on the store side: the target is a reversed view of dest.
    dest[::-1] = src

def array_overlap12(src, dest):
    """Store *src*, reversed along the first axis, into all of *dest*.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    flipped = src[::-1]
    dest[:] = flipped

def array_overlap13(src, dest):
    """Store *src* into *dest* through a view reversed along the second axis.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert src.shape == dest.shape
    # The reversal is on the store side: the target is a reversed view of dest.
    dest[:,::-1] = src

def array_overlap14(src, dest):
    """Store *src*, reversed along the second axis, into all of *dest*.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    flipped = src[:,::-1]
    dest[:] = flipped

def array_overlap15(src, dest):
    """Store *src* into *dest* through a view reversed along the last axis.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert src.shape == dest.shape
    # The reversal is on the store side: the target is a reversed view of dest.
    dest[...,::-1] = src

def array_overlap16(src, dest):
    """Store *src*, reversed along the last axis, into all of *dest*.

    *src* and *dest* must have the same shape; they may be the same array.
    """
    assert dest.shape == src.shape
    flipped = src[...,::-1]
    dest[:] = flipped


class TestArrayOverlap(TestCase):
    """Compare jitted slice assignments against the interpreter when the
    source and destination are the very same array."""

    def check_overlap(self, pyfunc, min_ndim, have_k_argument=False):
        """Run *pyfunc* and its nopython-compiled version on aliased arrays.

        For every dimensionality from *min_ndim* up to 3 and for each array
        layout produced by ``vary_layouts``, the kernel is called with the
        same array passed as both ``src`` and ``dest``; the array mutated by
        the pure-Python call must precisely equal the one mutated by the
        compiled call.

        Parameters
        ----------
        pyfunc : callable
            Kernel taking ``(src, dest)`` or ``(src, dest, k=...)``.
        min_ndim : int
            Smallest number of dimensions the kernel's indexing supports.
        have_k_argument : bool
            When true, the kernel is exercised with ``k`` = 1 .. N-1.
        """
        N = 4
        layout_names = ('C', 'F', 'A')

        def vary_layouts(orig):
            # Yields one array per entry of layout_names, in the same order.
            yield orig.copy(order='C')
            yield orig.copy(order='F')
            a = orig[::-1].copy()[::-1]
            # Sanity check: the double reversal must leave a layout that is
            # neither C- nor F-contiguous.
            self.assertFalse(a.flags.c_contiguous or a.flags.f_contiguous)
            yield a

        cfunc = jit(nopython=True)(pyfunc)

        if have_k_argument:
            all_kwargs = [dict(k=k) for k in range(1, N)]
        else:
            all_kwargs = [{}]

        # Check for up to 3d arrays
        for ndim in range(min_ndim, 4):
            shape = (N,) * ndim
            orig = np.arange(0, N**ndim).reshape(shape)
            for kwargs in all_kwargs:
                # Build fresh arrays for every set of arguments: the kernels
                # mutate their operands, and reusing the arrays across
                # several k values quickly degrades them into near-constant
                # data on which a wrong copy order would go unnoticed.
                # Note we cannot copy a 'A' layout array exactly (bitwise),
                # so instead we call vary_layouts() twice
                cases = zip(layout_names,
                            vary_layouts(orig), vary_layouts(orig))
                for layout, pydest, cdest in cases:
                    with self.subTest(ndim=ndim, layout=layout, **kwargs):
                        pyfunc(pydest, pydest, **kwargs)
                        cfunc(cdest, cdest, **kwargs)
                        self.assertPreciseEqual(pydest, cdest)

    def check_overlap_with_k(self, pyfunc, min_ndim):
        """Same as check_overlap(), for kernels taking a ``k`` displacement."""
        self.check_overlap(pyfunc, min_ndim=min_ndim, have_k_argument=True)

    def test_overlap1(self):
        self.check_overlap_with_k(array_overlap1, min_ndim=1)

    def test_overlap2(self):
        self.check_overlap_with_k(array_overlap2, min_ndim=1)

    def test_overlap3(self):
        self.check_overlap_with_k(array_overlap3, min_ndim=2)

    def test_overlap4(self):
        self.check_overlap_with_k(array_overlap4, min_ndim=2)

    def test_overlap5(self):
        self.check_overlap_with_k(array_overlap5, min_ndim=1)

    def test_overlap6(self):
        self.check_overlap_with_k(array_overlap6, min_ndim=1)

    def test_overlap11(self):
        self.check_overlap(array_overlap11, min_ndim=1)

    def test_overlap12(self):
        self.check_overlap(array_overlap12, min_ndim=1)

    def test_overlap13(self):
        self.check_overlap(array_overlap13, min_ndim=2)

    def test_overlap14(self):
        self.check_overlap(array_overlap14, min_ndim=2)

    def test_overlap15(self):
        self.check_overlap(array_overlap15, min_ndim=1)

    def test_overlap16(self):
        self.check_overlap(array_overlap16, min_ndim=1)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()