utils.py 1.99 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import functools
import operator
import os

import torch
from torch import nn

from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton


class MemoryBuffer(metaclass=Singleton):
    """
    Global memory buffer to store intermediate activations that need not to be cached for the backward pass.
    """

    def __init__(self):
        self.buffer = {}

    def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        required_numel = functools.reduce(operator.mul, shape, 1)
        if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
            self.buffer[name, dtype] = torch.empty(
                required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
            )
        return self.buffer[name, dtype][:required_numel].view(shape)


def assert_cuda_max_connections_set_to_1(func):
    flag_is_set_to_1 = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal flag_is_set_to_1
        if flag_is_set_to_1 is None:
            assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1"
            flag_is_set_to_1 = True
        return func(*args, **kwargs)

    return wrapper


def initial_sync(model: nn.Module, parallel_context: ParallelContext):
    # Synchronize across dp: basic assumption
    sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0])
    for name, param in sorted_name_params:
        dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)

    # Synchronize across tied weights: basic assumption
    for (_, group_ranks), param in sorted(
        get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0]
    ):
        group = parallel_context.world_ranks_to_pg[group_ranks]
        dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)