test_curand.py 1.43 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import pickle
import unittest

import numpy

import cupy
from cupy.cuda import curand


class TestGenerateNormal(unittest.TestCase):
    """Argument validation for the cuRAND normal / log-normal wrappers.

    Each test asks for a single sample (an odd count) and expects the
    ``curand`` binding to reject the request with ``ValueError``.  The cuRAND
    host API documents that normal and log-normal generation from
    pseudo-random generators requires an even number of outputs.
    """

    def setUp(self):
        self.generator = curand.createGenerator(
            curand.CURAND_RNG_PSEUDO_DEFAULT)
        # Release the native generator once the test finishes.  Without this,
        # every test method leaks one cuRAND generator handle.
        self.addCleanup(curand.destroyGenerator, self.generator)

    def _check_odd_size_raises(self, generate, dtype):
        """Assert that ``generate`` rejects a one-element request.

        Args:
            generate: One of the ``curand.generate*Normal*`` bindings. It is
                called as ``generate(generator, ptr, n, mean, stddev)``.
            dtype: Element dtype of the one-element output buffer. It should
                match the precision ``generate`` writes (float32 for the
                single-precision bindings, float64 for the ``*Double`` ones).
        """
        out = cupy.empty((1,), dtype=dtype)
        with self.assertRaises(ValueError):
            generate(self.generator, out.data.ptr, 1, 0.0, 1.0)

    def test_invalid_argument_normal_float(self):
        self._check_odd_size_raises(curand.generateNormal, numpy.float32)

    def test_invalid_argument_normal_double(self):
        self._check_odd_size_raises(
            curand.generateNormalDouble, numpy.float64)

    def test_invalid_argument_log_normal_float(self):
        self._check_odd_size_raises(curand.generateLogNormal, numpy.float32)

    def test_invalid_argument_log_normal_double(self):
        self._check_odd_size_raises(
            curand.generateLogNormalDouble, numpy.float64)


class TestExceptionPicklable(unittest.TestCase):
    """``curand.CURANDError`` must survive a pickle round trip intact."""

    def test(self):
        original = curand.CURANDError(100)
        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        # The rebuilt error must carry the same constructor args and render
        # the same message as the instance that was pickled.
        assert original.args == restored.args
        assert str(original) == str(restored)