test_device_function.py 2.99 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import unittest
import numpy

import pytest

import cupy
from cupyx import jit
from cupy import testing


class TestDeviceFunction(unittest.TestCase):
    """Tests for calling ``device=True`` JIT functions from JIT kernels.

    Every kernel is launched with a grid of ``(1,)``, so ``jit.threadIdx.x``
    alone is used to index the arrays.
    """

    def test_device_function(self):
        """A kernel can call a device function that is defined after it."""
        @jit.rawkernel()
        def f(x, y, z):
            tid = jit.threadIdx.x
            z[tid] = g(x[tid], y[tid]) + x[tid] + y[tid]

        # g rebinds its own parameters; the expected value below only holds
        # if the caller's x[tid] / y[tid] are left untouched by that.
        @jit.rawkernel(device=True)
        def g(x, y):
            x += 1
            y += 1
            return x + y

        x = testing.shaped_random((30,), dtype=numpy.int32, seed=0)
        y = testing.shaped_random((30,), dtype=numpy.int32, seed=1)
        z = testing.shaped_random((30,), dtype=numpy.int32, seed=2)
        f((1,), (30,), (x, y, z))
        # g returns (x + 1) + (y + 1); adding x + y gives 2 * (x + y + 1).
        assert (z == (x + y + 1) * 2).all()

    def test_device_function_duplicated_names(self):
        """Device functions sharing one Python name stay distinct.

        Each call to the factory ``g`` returns a new device function that is
        also named ``f`` (the same name as the kernel) but closes over a
        different ``const``.
        """
        @jit.rawkernel()
        def f(x, y, z):
            tid = jit.threadIdx.x
            z[tid] = g(10)(x[tid], y[tid])
            z[tid] += g(20)(x[tid], y[tid])
            z[tid] += g(30)(x[tid], y[tid])

        def g(const):
            @jit.rawkernel(device=True)
            def f(x, y):
                return x + y + const
            return f

        x = testing.shaped_random((30,), dtype=numpy.int32, seed=0)
        y = testing.shaped_random((30,), dtype=numpy.int32, seed=1)
        z = testing.shaped_random((30,), dtype=numpy.int32, seed=2)
        f((1,), (30,), (x, y, z))
        # Three calls of (x + y + const) with const = 10, 20 and 30.
        assert (z == (x + y) * 3 + 60).all()

    def test_device_function_recursive(self):
        """Launching a kernel whose device function calls itself fails."""
        @jit.rawkernel()
        def f(x, y, z):
            tid = jit.threadIdx.x
            z[tid] = g(x[tid], y[tid])

        @jit.rawkernel(device=True)
        def g(x, y):
            return x + y + g(x, y)

        x = testing.shaped_random((30,), dtype=numpy.int32, seed=0)
        y = testing.shaped_random((30,), dtype=numpy.int32, seed=1)
        z = testing.shaped_random((30,), dtype=numpy.int32, seed=2)
        with pytest.raises(ValueError):
            f((1,), (30,), (x, y, z))

    def test_device_function_template_recursion(self):
        """Recursion through a host-side factory performs a tree reduction.

        ``g(step)`` builds a separate device function per ``step`` value
        (1, 2, 4, ...); the chain ends once ``step`` reaches 256 because the
        body is guarded by ``step < 256``. After the last level ``y[0]``
        holds the sum of all 256 inputs.
        """
        @jit.rawkernel()
        def f(x, y):
            tid = jit.threadIdx.x
            y[tid] = x[tid]
            jit.syncthreads()
            g(1)(y)

        def g(step):
            # Only the accumulation is thread-dependent. The barrier and the
            # next recursion level sit directly under ``step < 256``, which
            # is the same for every thread, so the whole block reaches each
            # ``syncthreads`` together. A barrier inside the
            # ``tid % (step * 2) == 0`` branch would be reached by only a
            # subset of the block, which CUDA does not permit.
            @jit.rawkernel(device=True)
            def f(x):
                if step < 256:
                    tid = jit.threadIdx.x
                    if tid % (step * 2) == 0:
                        x[tid] += x[tid + step]
                    jit.syncthreads()
                    g(step * 2)(x)
            return f

        x = testing.shaped_random((256,), dtype=numpy.int32, seed=0)
        y = testing.shaped_random((256,), dtype=numpy.int32, seed=1)
        f((1,), (256,), (x, y))
        assert y[0] == x.sum()

    def test_device_function_called_once(self):
        """A device call used as an augmented-assignment index runs once.

        ``g`` increments ``x[0]`` and returns 1. If ``x[g(x)] += 1`` were to
        evaluate ``g(x)`` twice, ``x[0]`` would end up as 2 instead of 1.
        """
        @jit.rawkernel(device=True)
        def g(x):
            x[0] += 1
            return 1

        @jit.rawkernel()
        def f(x):
            x[g(x)] += 1

        x = cupy.array([0, 0])
        f((1,), (1,), (x,))
        testing.assert_array_equal(x, [1, 1])