# random.py — v1.0.8, committed by chenzk.
# (Web-viewer chrome — file size header and display line numbers — removed.)
import contextlib
import random
from dataclasses import dataclass
from typing import MutableMapping, Optional, Tuple

import numpy as np
import torch

from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup


@dataclass
class RandomState:
    """Snapshot of all the global RNG states we track.

    Fields mirror what each library's getter returns:
     - `random`: `random.getstate()` tuple
     - `numpy`: `np.random.get_state()` tuple (contains an ndarray)
     - `torch_cpu`: `torch.random.get_rng_state()` tensor
     - `torch_cuda`: `torch.cuda.get_rng_state("cuda")` tensor, or None when CUDA is absent
    """

    random: Tuple[int, Tuple[int, ...], None]
    numpy: Tuple[str, np.ndarray, int, int, float]
    torch_cpu: torch.Tensor
    torch_cuda: Optional[torch.Tensor]

    def __eq__(self, other):
        # Field-by-field comparison: the numpy state embeds an ndarray and the
        # torch states are tensors, so a plain tuple `==` would not work.
        if not isinstance(other, RandomState):
            return False
        if not all(a == b for a, b in zip(self.random, other.random)):
            return False
        for a, b in zip(self.numpy, other.numpy):
            if isinstance(a, np.ndarray):
                if not np.array_equal(a, b):
                    return False
            elif a != b:
                return False
        if not torch.equal(self.torch_cpu, other.torch_cpu):
            return False
        if self.torch_cuda is None:
            return other.torch_cuda is None
        return torch.equal(self.torch_cuda, other.torch_cuda)


class RandomStates(MutableMapping[str, RandomState]):
    """Mapping from a name to a `RandomState`, with a frozen key set.

    Values may be replaced after construction, but keys can be neither
    added nor deleted.
    """

    def __init__(self, dict: dict):
        # NOTE(review): the parameter shadows the `dict` builtin; kept as-is so
        # keyword callers (`RandomStates(dict=...)`) are not broken.
        for key, value in dict.items():
            self.check_type(key, value)
        # TODO @thomasw21: We make a copy for safety measure.
        self._dict = dict.copy()

    @staticmethod
    def check_type(key, value):
        """Raise ValueError unless `key` is a str and `value` a RandomState."""
        if not isinstance(key, str):
            raise ValueError(f"Expected key to be of type str. Got {type(key)}")
        if not isinstance(value, RandomState):
            raise ValueError(f"Expected value to be of type `nanotron.dataclass.RandomState`. Got {type(value)}")

    def __getitem__(self, item):
        return self._dict[item]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __delitem__(self, key):
        # Key set is frozen: deletion is never allowed.
        raise ValueError("Can't delete a random states key")

    def __setitem__(self, key, value):
        # Only existing keys may be updated, and only with validated values.
        if key not in self._dict:
            raise ValueError("Can't add a new random states after initialisation")
        self.check_type(key, value)
        self._dict[key] = value

    def __eq__(self, other):
        return isinstance(other, RandomStates) and self._dict == other._dict


def set_random_seed(seed: int):
    """Seed every RNG we rely on: python `random`, numpy, torch CPU and, when
    available, the current CUDA device."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def set_random_state(random_state: RandomState):
    """Restore the global RNGs from a previously captured snapshot.

    On a CUDA-less machine the snapshot must have been taken without CUDA
    (`torch_cuda is None`), otherwise the assertion fires.
    """
    random.setstate(random_state.random)
    np.random.set_state(random_state.numpy)
    torch.set_rng_state(random_state.torch_cpu)
    if not torch.cuda.is_available():
        assert random_state.torch_cuda is None
    else:
        torch.cuda.set_rng_state(random_state.torch_cuda, "cuda")


def get_current_random_state():
    """Capture the current global RNG states into a `RandomState` snapshot."""
    cuda_state = torch.cuda.get_rng_state("cuda") if torch.cuda.is_available() else None
    return RandomState(
        random=random.getstate(),
        numpy=np.random.get_state(),
        torch_cpu=torch.random.get_rng_state(),
        torch_cuda=cuda_state,
    )


@contextlib.contextmanager
def branch_random_state(random_states: RandomStates, key: str, enabled: bool):
    """Temporarily switch the global RNGs to the state stored under `key`.

    On entry: the caller's current state is saved and the state at `key` is
    installed. On exit: `random_states[key]` is refreshed with whatever the
    state advanced to inside the block, then the saved state is restored.
    When `enabled` is False this is a plain passthrough.
    """
    if not enabled:
        yield
        return

    saved_state = get_current_random_state()
    set_random_state(random_states[key])

    try:
        yield
    finally:
        # Persist the advanced state under `key`, then put the caller's
        # original state back — even if the body raised.
        random_states[key] = get_current_random_state()
        set_random_state(saved_state)


def get_synced_random_state(
    random_state: RandomState,
    pg: ProcessGroup,
):
    """Return a random state identical on every rank of `pg`, by broadcasting
    rank 0's `random_state` to the whole group.

    NOTE(review): the broadcast device is hard-coded to CUDA — presumably this
    path only runs on GPU setups; confirm before using on CPU-only machines.
    """
    reference_rank = 0
    # Only the reference rank contributes its state; every other rank receives.
    obj_list = [random_state if dist.get_rank(pg) == reference_rank else None]

    # TODO @thomasw21: broadcast tensor using `broadcast` in order not to use pickle
    dist.broadcast_object_list(
        obj_list, src=dist.get_global_rank(pg, reference_rank), group=pg, device=torch.device("cuda")
    )

    synced_random_state = obj_list[0]
    assert synced_random_state is not None
    return synced_random_state