test_rsqrt.py 454 Bytes
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import unittest

import numpy

import cupy
from cupy import testing
import cupyx


class TestRsqrt(unittest.TestCase):
    """Tests for ``cupyx.rsqrt`` checked against a NumPy reference."""

    @testing.for_all_dtypes(no_complex=True)
    def test_rsqrt(self, dtype):
        """``cupyx.rsqrt(x)`` should match ``1.0 / numpy.sqrt(x)``."""
        # Shift every element up by 1.0 so no input is zero; this keeps the
        # reciprocal square root free of division by zero.
        # NOTE(review): adding a Python float likely promotes integer/bool
        # dtypes to float64, so those cases may all exercise the float64
        # path -- confirm whether that is intended.
        host_input = testing.shaped_arange((2, 3), numpy, dtype) + 1.0

        # Reference computed on the host as 1 / sqrt(x); the original note
        # records that numpy.sqrt is broken in numpy<1.11.2.
        expected = 1.0 / numpy.sqrt(host_input)

        actual = cupyx.rsqrt(cupy.array(host_input))
        testing.assert_allclose(actual, expected)