test_random.py 1.67 KB
Newer Older
Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
1
2
3
import pytest
import torch

liangjing's avatar
liangjing committed
4
5
6
7
8
9
10
11
12
from megatron.core.tensor_parallel.random import (
    CudaRNGStatesTracker,
    checkpoint,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)
from tests.unit_tests.test_utilities import Utils


Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
13
14
def test_cuda_rng_states_tracker():
    """Exercise CudaRNGStatesTracker's state bookkeeping.

    Covers the set_states/get_states round trip, reset(), rejection of a
    reused seed and of a reused state name in add(), and fork() on a
    registered state.
    """
    rng_tracker = CudaRNGStatesTracker()

    # set_states/get_states round-trip an arbitrary mapping; reset() clears it.
    rng_tracker.set_states({"state1": 1234})
    assert rng_tracker.get_states()["state1"] == 1234
    rng_tracker.reset()
    assert rng_tracker.get_states() == {}

    seed = 1111
    rng_tracker.add("state2", seed)
    # Call add() directly inside pytest.raises. Wrapping it in `assert` made
    # these checks vacuous: if add() returned a falsy value without raising,
    # the `assert` itself raised AssertionError, which pytest.raises(Exception)
    # also accepts.
    with pytest.raises(Exception):
        rng_tracker.add("state3", seed)  # seed already used by "state2"
    with pytest.raises(Exception):
        rng_tracker.add("state2", 111)  # name "state2" already registered
    assert rng_tracker.get_states()['state2'] is not None

    # "state2" should hold the CUDA RNG state produced by seeding with `seed`.
    torch.cuda.manual_seed(seed)
    expected_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], expected_state)

    # fork() is used as a context manager; a bare call never enters it. Enter
    # it and check that the tracked state is the active CUDA RNG state inside.
    with rng_tracker.fork("state2"):
        assert torch.equal(torch.cuda.get_rng_state(), expected_state)

liangjing's avatar
liangjing committed
34

Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
35
def test_model_parallel_cuda_manual_seed():
    """Check that model_parallel_cuda_manual_seed registers a
    'model-parallel-rng' state on the tracker from get_cuda_rng_tracker()."""
    # Arguments are forwarded to the test utility; presumably
    # (tensor-parallel size, pipeline-parallel size) -- TODO confirm.
    Utils.initialize_model_parallel(4, 2)
    try:
        model_parallel_cuda_manual_seed(0)
        rng_tracker = get_cuda_rng_tracker()
        assert rng_tracker.get_states()['model-parallel-rng'] is not None
    finally:
        # Tear down even when the assertion fails, so later tests do not
        # inherit this parallel configuration.
        Utils.destroy_model_parallel()

liangjing's avatar
liangjing committed
42

Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
43
44
def test_checkpoint():
    """Check checkpoint() output without model parallel, then its effect on
    the first input once model parallel is initialized."""

    def add_first_two(*inputs):
        # Local helper (formerly `test_forward(*input)`): the rename avoids
        # shadowing the builtin `input` and a test_-prefixed non-test name.
        return inputs[0] + inputs[1]

    # NOTE(review): the second positional argument (None here, True below)
    # looks like a flag controlling how saved activations are handled --
    # confirm against checkpoint()'s signature.
    assert torch.equal(
        torch.ones(16) * 3, checkpoint(add_first_two, None, torch.ones(16), torch.ones(16) * 2)
    )

    Utils.initialize_model_parallel()
    try:
        input1 = torch.ones((4, 4))
        checkpoint(add_first_two, True, input1, torch.ones((4, 4)) * 2)
        # The assertion expects input1 to now equal a flat CUDA tensor of
        # ones, i.e. checkpoint() rewrote the first input in place.
        # NOTE(review): confirm this is the intended contract.
        assert torch.equal(torch.ones(input1.numel()).cuda(), input1)
    finally:
        # Tear down even when the assertion fails.
        Utils.destroy_model_parallel()