import dataclasses
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch
from torch import nn

from nanotron import distributed as dist
from nanotron import logging

if TYPE_CHECKING:
    from nanotron.models import NanotronModel

logger = logging.get_logger(__name__)


@dataclasses.dataclass
class SlicesPair:
    local_slices: Tuple[slice, ...]
    global_slices: Tuple[slice, ...]

    @staticmethod
    def slice_to_str(s: slice):
        # e.g. slice(0, 10, 2) -> "0,10,2"
        # e.g. slice(None, None, None) -> "None,None,None"
        return ",".join(str(x) if x is not None else "None" for x in (s.start, s.stop, s.step))

    @staticmethod
    def str_to_slice(s: str):
        return slice(*(int(x) if x != "None" else None for x in s.split(",")))

    def __str__(self):
        # e.g. local_slices (slice(0, 10, 2), slice(None, None, None)) -> "0,10,2|None,None,None"
        local_slices_str = "|".join(map(self.slice_to_str, self.local_slices))
        # e.g. global_slices (slice(0, 20, 4), slice(None, None, None)) -> "0,20,4|None,None,None"
        global_slices_str = "|".join(map(self.slice_to_str, self.global_slices))
        # e.g. "0,10,2|None,None,None#0,20,4|None,None,None"
        return f"{local_slices_str}#{global_slices_str}"

    @classmethod
    def from_str(cls, string: str):
        local_slices_str, global_slices_str = string.split("#")
        local_slices = tuple(map(cls.str_to_slice, local_slices_str.split("|")))
        global_slices = tuple(map(cls.str_to_slice, global_slices_str.split("|")))
        return cls(local_slices, global_slices)

    @classmethod
    def tuple_to_str(cls, pairs):
        # e.g. two SlicesPair: the first with local_slices "0,10,2|None,None,None" and global_slices "0,10,2|None,None,None",
        # the second with local_slices "0,20,4|None,None,None" and global_slices "0,40,8|None,None,None"
        # -> "0,10,2|None,None,None#0,10,2|None,None,None;0,20,4|None,None,None#0,40,8|None,None,None"
        return ";".join(map(str, pairs))

    @classmethod
    def tuple_from_str(cls, string: str):
        return tuple(map(cls.from_str, string.split(";")))

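# Illustrative round-trip sketch (values made up for the example, not from the
# library): a SlicesPair serializes to "<local>#<global>" and parses back.
#
#   pair = SlicesPair(
#       local_slices=(slice(0, 10, 2), slice(None, None, None)),
#       global_slices=(slice(0, 20, 4), slice(None, None, None)),
#   )
#   assert str(pair) == "0,10,2|None,None,None#0,20,4|None,None,None"
#   assert SlicesPair.from_str(str(pair)) == pair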

@dataclasses.dataclass
class TiedInfo:
    name: str
    # name must be defined starting from `root_module` (e.g. root_module.dense0.dense1.weight)
    root_module: nn.Module
    global_ranks: Tuple[int, ...]
    # None signifies that we do not reduce
    reduce_op: Optional[dist.ReduceOp]

    def get_full_name_from_model(self, model: nn.Module) -> str:
        module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
        # Fix the root_model
        module_id_to_prefix[id(model)] = ""
        return self.get_full_name_from_module_id_to_prefix(module_id_to_prefix)

    def get_full_name_from_module_id_to_prefix(self, module_id_to_prefix: Dict[int, str]) -> str:
        return f"{module_id_to_prefix[id(self.root_module)]}{self.name}"  # this assumes root_module is part of module_id_to_prefix

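# Illustrative sketch (made-up module, not from the library): resolving the full
# name of a tied parameter starting from the model root.
#
#   model = nn.Sequential(nn.Linear(4, 4))
#   info = TiedInfo(name="weight", root_module=model[0], global_ranks=(0,), reduce_op=None)
#   info.get_full_name_from_model(model)  # -> "0.weight"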

@dataclasses.dataclass
class ShardedInfo:
    global_ranks: Tuple[int, ...]
    # Maps the slices of the local (sharded) tensor (local_slices) to the corresponding slices of the unsharded tensor (global_slices)
    local_global_slices_pairs: Tuple[SlicesPair, ...]
    # The shape of the unsharded tensor
    unsharded_shape: Tuple[int, ...]

    def is_tp_sharded(self, parallel_context) -> bool:
        return set(dist.get_global_ranks(parallel_context.tp_pg)).issubset(set(self.global_ranks))

    def is_expert_sharded(self, parallel_context) -> bool:
        return set(dist.get_global_ranks(parallel_context.expert_pg)).issubset(set(self.global_ranks))

    def is_dp_sharded(self, parallel_context) -> bool:
        return set(dist.get_global_ranks(parallel_context.dp_pg)).issubset(set(self.global_ranks))


class NanotronParameter(nn.Parameter):
    """Base class for all parameters in Nanotronmodels

    A NanotronParameter can have specific properties:
     - sharded: the parameter is considered to be `sharded` across multiple devices
     - tied: the parameter is considered to be `tied` with other parameters. We sum gradients over those.

    .. note::
        Notes about tied weights:
        - Tied weights are weights that need to be synced only within the same DP rank, regardless of whether they are part of the TP strategy or simply shared between two layers.
        - Syncing tied weights usually requires summing their gradients.
        - Some weights are synced without needing to reduce grads across ranks. They can live on the same device (e.g. encoder/decoder embeddings in the same PP stage), or they can be duplicated across TP ranks, duplicating the workload (e.g. LayerNorm with traditional TP).
        - Even if some weights don't need their grads to be reduced, it's still useful to mark them as tied. For example, the current serialization format requires them to be marked correctly.
    """

    NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME = "__nanotron_metadata__"
    NANOTRON_PARAMETER_METADATA_TIED_KEY = "tied"
    NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded"

    def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True):
        param = nn.Parameter.__new__(cls, data=tensor.data.detach(), requires_grad=requires_grad)

        if isinstance(tensor, NanotronParameter):
            # Check that we don't inherit from an unexpected subclass
            # Copy the metadata so we don't modify the original in place
            assert type(tensor) == NanotronParameter
            setattr(
                param,
                cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME,
                getattr(tensor, cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME).copy(),
            )
        else:
            setattr(param, cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, {})

        return param

    def _set_metadata(self, key: str, value: Any):
        metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)

        if key in metadata:
            raise ValueError(
                f"We shouldn't override previous metadata. Key to be overridden: {key}, current metadata: {metadata}"
            )
        else:
            metadata[key] = value

    def mark_as_tied(
        self,
        name: str,
        global_ranks: Tuple[int, ...],
        reduce_op: Optional[dist.ReduceOp],
        root_module: "NanotronModel",
    ):
        self._set_metadata(
            self.NANOTRON_PARAMETER_METADATA_TIED_KEY,
            TiedInfo(name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module),
        )

    def get_tied_info(self) -> TiedInfo:
        return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[
            self.NANOTRON_PARAMETER_METADATA_TIED_KEY
        ]

    @property
    def is_tied(self) -> bool:
        return self.NANOTRON_PARAMETER_METADATA_TIED_KEY in getattr(
            self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
        )

    def mark_as_sharded(
        self,
        global_ranks: Tuple[int, ...],
        local_global_slices_pairs: Tuple[SlicesPair, ...],
        unsharded_shape: Tuple[int, ...],
    ):
        self._set_metadata(
            self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY,
            ShardedInfo(
                global_ranks=global_ranks,
                local_global_slices_pairs=local_global_slices_pairs,
                unsharded_shape=unsharded_shape,
            ),
        )

    def get_sharded_info(self) -> ShardedInfo:
        return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[
            self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY
        ]

    @property
    def is_sharded(self) -> bool:
        return self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY in getattr(
            self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
        )

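# Illustrative usage sketch (shapes and ranks are made up, not from the library):
# mark a parameter as the first half of a tensor row-sharded across two ranks.
#
#   weight = NanotronParameter(torch.empty(1024, 4096))
#   weight.mark_as_sharded(
#       global_ranks=(0, 1),
#       local_global_slices_pairs=(
#           SlicesPair(local_slices=(slice(None),), global_slices=(slice(0, 1024),)),
#       ),
#       unsharded_shape=(2048, 4096),
#   )
#   assert weight.is_sharded
#   assert weight.get_sharded_info().unsharded_shape == (2048, 4096)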

def sanity_check(root_module: nn.Module):
    """Makes sure that the module is in Nanotronformat

    Format:
     - all parameters are `NanotronParameter`, which allows us to attach metadata to a parameter.
    """
    for name, param in root_module.named_parameters():
        if not isinstance(param, NanotronParameter):
            raise ValueError(
                f"Nanotronrequires model to be in Nanotronformat, ie all parameters are required to be a NanotronParameter. {name} isn't."
            )
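# Illustrative sketch (hypothetical conversion, not from the library): sanity_check
# passes only once every parameter has been wrapped as a NanotronParameter.
#
#   module = nn.Linear(4, 4)
#   for param_name, param in list(module.named_parameters()):
#       # works here because the module is flat; nested parameter names contain dots
#       setattr(module, param_name, NanotronParameter(param))
#   sanity_check(module)  # ok; a plain nn.Linear would raise ValueError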