test_random.py 431 Bytes
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import unittest

from cupy import random
from cupy import testing


class TestResetSeed(unittest.TestCase):
    """Verify that re-seeding the global random state reproduces its samples."""

    @testing.for_float_dtypes(no_float16=True)
    def test_reset_seed(self, dtype):
        # Seeding with 0 before each draw must yield identical ``rand`` output
        # for every float dtype supplied by the decorator (float16 excluded).
        def draw():
            # Re-fetch the global state on every call, seed it, then sample.
            state = random.get_random_state()
            state.seed(0)
            return state.rand(10, dtype=dtype)

        first, second = draw(), draw()
        testing.assert_array_equal(first, second)