"vscode:/vscode.git/clone" did not exist on "51ceef84b46693590766a6a9bb97e44484ddb675"
test_strided_tricks.py 1.49 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import unittest

import numpy

import cupy
from cupy import testing
from cupy.lib import stride_tricks


class TestAsStrided(unittest.TestCase):
    """Checks for ``cupy.lib.stride_tricks.as_strided`` views."""

    def test_as_strided(self):
        # A stride of two items along the only axis keeps every other
        # element of the source array.
        source = cupy.array([1, 2, 3, 4])
        two_items = 2 * source.itemsize
        every_other = stride_tricks.as_strided(
            source, shape=(2,), strides=(two_items,))
        testing.assert_array_equal(every_other, cupy.array([1, 3]))

        # A zero stride on the leading axis makes each of the three rows
        # alias the same four source elements.
        source = cupy.array([1, 2, 3, 4])
        one_item = 1 * source.itemsize
        repeated_rows = stride_tricks.as_strided(
            source, shape=(3, 4), strides=(0, one_item))
        testing.assert_array_equal(
            repeated_rows,
            cupy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]))

    @testing.numpy_cupy_array_equal()
    def test_rolling_window(self, xp):
        # ``xp`` is the array module handed in by the decorator; the array
        # returned here is what the decorator gets to compare
        # (NOTE(review): see cupy.testing docs for the exact comparison).
        grid = testing.shaped_arange((3, 4), xp)
        return rolling_window(grid, 2, 0)


def rolling_window(a, window, axis=-1):
    """
    Make an ndarray with a rolling window along axis.

    This function is taken from https://github.com/numpy/numpy/pull/31
    but slightly modified to accept axis option.

    The result is produced with ``as_strided`` (NumPy's or CuPy's, chosen
    from the type of the array), so it has one more dimension than ``a``.

    Args:
        a: Array to take windows from. It is first passed through
            ``numpy.swapaxes`` and must then be a ``numpy.ndarray`` or a
            ``cupy.ndarray``.
        window (int): Number of consecutive elements in each window.
        axis (int): Axis of ``a`` along which the windows are taken.

    Returns:
        The strided array after a final ``swapaxes(-2, axis)``. For
        example, a ``(n, m)`` input with ``axis=0`` gives shape
        ``(n - window + 1, m, window)`` where
        ``out[j, i, k] == a[j + k, i]``.

    Raises:
        TypeError: If the array is neither a NumPy nor a CuPy ndarray.
    """
    # Move the rolled axis to the end so the window can be appended as a
    # new trailing dimension that reuses the same stride.
    swapped = numpy.swapaxes(a, axis, -1)

    # Pick the as_strided implementation up front; previously an
    # unsupported type fell through both branches and failed later with an
    # UnboundLocalError on the return statement.
    if isinstance(swapped, numpy.ndarray):
        as_strided = numpy.lib.stride_tricks.as_strided
    elif isinstance(swapped, cupy.ndarray):
        as_strided = stride_tricks.as_strided
    else:
        raise TypeError(
            'rolling_window expects a numpy.ndarray or cupy.ndarray, '
            'got {}'.format(type(swapped).__name__))

    n_positions = swapped.shape[-1] - window + 1
    shape = swapped.shape[:-1] + (n_positions, window)
    # Stepping inside a window advances by the same stride as stepping
    # between window start positions.
    strides = swapped.strides + (swapped.strides[-1],)

    rolling = as_strided(swapped, shape=shape, strides=strides)
    return rolling.swapaxes(-2, axis)