"multi_node/run-7b-sft-lora.sh" did not exist on "829d3ffa9ed1784fd830e79cbe842418899dc1f5"
test_complex_kernel.py 497 Bytes
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import numpy as np
from numba import cuda
from numba.cuda.testing import unittest, CUDATestCase


class TestCudaComplex(CUDATestCase):
    """Tests for passing complex-valued arguments to CUDA kernels."""

    def test_cuda_complex_arg(self):
        """A complex128 scalar kernel argument is added to every array element."""
        @cuda.jit('void(complex128[:], complex128)')
        def foo(a, b):
            i = cuda.grid(1)
            # Guard against threads past the end of the array so the kernel
            # stays in bounds even if the launch is padded beyond a's length.
            if i < a.shape[0]:
                a[i] += b

        a = np.arange(5, dtype=np.complex128)
        a0 = a.copy()
        # One block with one thread per element (a.shape is (5,)).
        foo[1, a.shape](a, 2j)
        # assert_allclose names the mismatching elements on failure, whereas
        # assertTrue(np.allclose(...)) only reports "False is not true".
        np.testing.assert_allclose(a, a0 + 2j)


if __name__ == '__main__':
    # Only when executed directly as a script (not when imported by a test
    # collector): hand control to the unittest runner imported above.
    unittest.main()