nn.py 10.2 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import nn

from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.sharded_parameters import (
    SplitConfig,
    create_sharded_parameter_from_config,
    mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
    differentiable_all_gather,
    differentiable_all_reduce_sum,
    differentiable_identity,
    differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
    column_linear,
    row_linear,
)
from nanotron.parallel.tied_parameters import create_tied_parameter


class TensorParallelColumnLinear(nn.Linear):
    """`nn.Linear` whose output features are split evenly across the ranks of `pg`.

    Each rank builds a local `nn.Linear` with `out_features // pg.size()` output
    features. All parameters of the module are then marked as sharded along
    dimension 0 through `mark_all_parameters_in_module_as_sharded`. The forward
    pass is delegated to `column_linear`.

    Args:
        in_features: Input feature size (not sharded).
        out_features: Unsharded output feature size; must be divisible by `pg.size()`.
        pg: Process group the layer is sharded over.
        mode: Tensor-parallel mode, forwarded to `column_linear` as `tp_mode`.
        bias: Forwarded to `nn.Linear`.
        device: Forwarded to `nn.Linear`.
        dtype: Forwarded to `nn.Linear`.
        async_communication: Forwarded to `column_linear`.
        contiguous_chunks: Optional chunk sizes stored in the `SplitConfig`; when
            given, they must sum to the unsharded `out_features`.
        tp_recompute_allgather: Forwarded to `column_linear`.
    """

    def __init__(
        self,
        in_features,
        out_features,
        pg: dist.ProcessGroup,
        mode: TensorParallelLinearMode,
        bias=True,
        device=None,
        dtype=None,
        async_communication: bool = False,
        contiguous_chunks: Optional[Tuple[int, ...]] = None,
        tp_recompute_allgather: bool = True,
    ):
        tp_size = pg.size()
        self.pg = pg
        self.world_size = tp_size

        # Every rank must own an equally sized slice of the output dimension.
        assert out_features % tp_size == 0
        local_out_features = out_features // tp_size

        self.in_features = in_features
        self.out_features = local_out_features
        self.tp_recompute_allgather = tp_recompute_allgather

        # The underlying `nn.Linear` only ever sees the local (sharded) shape.
        super().__init__(
            in_features,
            local_out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

        self.mode = mode
        self.async_communication = async_communication

        if contiguous_chunks is not None:
            chunk_total = sum(contiguous_chunks)
            assert (
                chunk_total == out_features
            ), f"Sum of contiguous chunks ({chunk_total}) must equal to out_features ({out_features})"

        # Parameters are split along dim 0.
        mark_all_parameters_in_module_as_sharded(
            self,
            pg=pg,
            split_config=SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sharded projection to `x` by delegating to `column_linear`."""
        return column_linear(
            input=x,
            weight=self.weight,
            bias=self.bias,
            group=self.pg,
            tp_mode=self.mode,
            async_communication=self.async_communication,
            tp_recompute_allgather=self.tp_recompute_allgather,
        )

    def extra_repr(self) -> str:
        """Extend the `nn.Linear` repr with the TP rank and the unsharded output size."""
        tp_rank = dist.get_rank(self.pg)
        unsharded_out_features = self.out_features * self.world_size
        return f"tp_rank={tp_rank}, {super().extra_repr()}, unsharded_out_features={unsharded_out_features}"


class TensorParallelRowLinear(nn.Linear):
    """`nn.Linear` whose input features are split evenly across the ranks of `pg`.

    Each rank builds a local `nn.Linear` with `in_features // pg.size()` input
    features. The weight is marked as sharded along dimension 1; the bias, when
    requested, is only created on rank 0 of `pg` and is wrapped as a plain
    `NanotronParameter`. The forward pass is delegated to `row_linear`.

    Args:
        in_features: Unsharded input feature size; must be divisible by `pg.size()`.
        out_features: Output feature size (not sharded).
        pg: Process group the layer is sharded over.
        mode: Tensor-parallel mode, forwarded to `row_linear` as `tp_mode`.
        bias: Whether a bias is requested; it only materialises on rank 0.
        device: Forwarded to `nn.Linear`.
        dtype: Forwarded to `nn.Linear`.
        async_communication: Forwarded to `row_linear`; rejected in `ALL_REDUCE` mode.
        contiguous_chunks: Optional chunk sizes stored in the `SplitConfig`; when
            given, they must sum to the unsharded `in_features`.

    Raises:
        ValueError: If `async_communication` is set while `mode` is `ALL_REDUCE`.
    """

    def __init__(
        self,
        in_features,
        out_features,
        pg: dist.ProcessGroup,
        mode: TensorParallelLinearMode,
        bias=True,
        device=None,
        dtype=None,
        async_communication: bool = False,
        contiguous_chunks: Optional[Tuple[int, ...]] = None,
    ):
        tp_size = pg.size()
        self.pg = pg
        self.world_size = tp_size

        # Every rank must own an equally sized slice of the input dimension.
        assert in_features % tp_size == 0
        local_in_features = in_features // tp_size

        self.in_features = local_in_features
        self.out_features = out_features

        # The bias is not sharded: only rank 0 keeps it, every other rank goes without.
        local_bias = bias if dist.get_rank(pg) == 0 else False

        super().__init__(
            local_in_features,
            out_features,
            bias=local_bias,
            device=device,
            dtype=dtype,
        )
        self.mode = mode
        self.async_communication = async_communication
        if async_communication and mode is TensorParallelLinearMode.ALL_REDUCE:
            raise ValueError("async_communication is not supported for ALL_REDUCE mode")

        if contiguous_chunks is not None:
            chunk_total = sum(contiguous_chunks)
            assert (
                chunk_total == in_features
            ), f"Sum of contiguous chunks ({chunk_total}) must equal to in_features ({in_features})"

        # The weight is split along dim 1.
        self._mark_all_parameters_in_module_as_sharded(
            SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks)
        )

    def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
        """Re-register every parameter: sharded for the weight, plain for the bias."""
        # Snapshot the parameters first, since `setattr` below re-registers them.
        for name, param in list(self.named_parameters()):
            if name != "bias":
                wrapped = create_sharded_parameter_from_config(
                    parameter=param,
                    pg=self.pg,
                    split_config=split_config,
                )
            else:
                # `bias` is present on rank 0 only and is not sharded.
                wrapped = NanotronParameter(tensor=param)
            setattr(self, name, wrapped)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sharded projection to `x` by delegating to `row_linear`."""
        return row_linear(
            input=x,
            weight=self.weight,
            bias=self.bias,
            group=self.pg,
            tp_mode=self.mode,
            async_communication=self.async_communication,
        )

    def extra_repr(self) -> str:
        """Extend the `nn.Linear` repr with the TP rank and the unsharded input size."""
        tp_rank = dist.get_rank(self.pg)
        unsharded_in_features = self.in_features * self.world_size
        return f"tp_rank={tp_rank}, {super().extra_repr()}, unsharded_in_features={unsharded_in_features}"


class TiedLinear(nn.Linear):
    """Full-size `nn.Linear` whose parameters are tied across all ranks of `pg`.

    The layer is not sharded: every rank holds the complete weight (and bias).
    Each parameter is registered through `create_tied_parameter` with the global
    ranks of `pg`; the tied `reduce_op` is `None` in `ALL_REDUCE` mode and
    `dist.ReduceOp.SUM` otherwise.

    Args:
        in_features: Forwarded to `nn.Linear`.
        out_features: Forwarded to `nn.Linear`.
        pg: Process group whose ranks share the parameters.
        mode: Tensor-parallel mode; selects the tied `reduce_op` and the
            primitive applied to the output in `forward`.
        bias: Forwarded to `nn.Linear`.
        device: Forwarded to `nn.Linear`.
        dtype: Forwarded to `nn.Linear`.
    """

    def __init__(
        self,
        in_features,
        out_features,
        pg: dist.ProcessGroup,
        mode: TensorParallelLinearMode,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.pg = pg
        self.world_size = pg.size()
        self.mode = mode

        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

        self._mark_all_parameters_in_module_as_tied()

    def _mark_all_parameters_in_module_as_tied(self):
        """Replace every parameter of the module with its tied counterpart."""
        # Both values are identical for every parameter, so compute them once.
        group_ranks = tuple(
            sorted(get_global_rank(self.pg, local_rank) for local_rank in range(self.pg.size()))
        )
        if self.mode is TensorParallelLinearMode.ALL_REDUCE:
            reduce_op = None
        else:
            reduce_op = dist.ReduceOp.SUM

        # Snapshot the parameters first, since `setattr` below re-registers them.
        for name, param in list(self.named_parameters()):
            tied = create_tied_parameter(
                parameter=param,
                name=name,
                global_ranks=group_ranks,
                reduce_op=reduce_op,
                root_module=self,
            )
            setattr(self, name, tied)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the plain linear projection, then apply the mode-specific primitive.

        Raises:
            ValueError: If `self.mode` is neither `ALL_REDUCE` nor `REDUCE_SCATTER`.
        """
        y = super().forward(x)

        if self.mode is TensorParallelLinearMode.ALL_REDUCE:
            return differentiable_identity(y, group=self.pg)
        if self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
            return differentiable_all_gather(y, group=self.pg)

        raise ValueError(f"Got unexpected mode: {self.mode}.")


class TensorParallelEmbedding(nn.Embedding):
    """`nn.Embedding` whose rows are split evenly across the ranks of `pg`.

    Each rank stores a contiguous block of `num_embeddings // pg.size()` rows and
    serves the ids in `[min_id, max_id[`. Ids outside that interval are looked up
    at local row 0 and their output is zeroed before the per-rank results are
    combined with the mode-specific reduction primitive.

    Args:
        num_embeddings: Unsharded number of rows; must be divisible by `pg.size()`.
        embedding_dim: Forwarded to `nn.Embedding`.
        pg: Process group the table is sharded over.
        mode: Tensor-parallel mode; selects the reduction used in `forward`.
        padding_idx: Forwarded to `nn.Embedding`.
            NOTE(review): it is passed to the local, shard-sized table, so it is
            interpreted relative to the local shard - confirm against callers.
        max_norm: Forwarded to `nn.Embedding`.
        norm_type: Forwarded to `nn.Embedding`.
        scale_grad_by_freq: Forwarded to `nn.Embedding`.
        sparse: Forwarded to `nn.Embedding`.
        _weight: Forwarded to `nn.Embedding` (which is built with the local shard size).
        device: Forwarded to `nn.Embedding`.
        dtype: Forwarded to `nn.Embedding`.
        contiguous_chunks: Optional chunk sizes stored in the `SplitConfig`; when
            given, they must sum to the unsharded `num_embeddings`.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        pg: dist.ProcessGroup,
        mode: TensorParallelLinearMode,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None,
        contiguous_chunks: Optional[Tuple[int, ...]] = None,
    ):
        self.pg = pg
        self.rank = dist.get_rank(self.pg)
        self.world_size = pg.size()

        self.original_num_embeddings = num_embeddings

        # TODO @thomasw21: Fix and remove that constraint. Typically there's no reason to have such a constraint.
        assert num_embeddings % self.world_size == 0
        block_size = num_embeddings // self.world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.rank * block_size
        self.max_id = (self.rank + 1) * block_size

        # The underlying `nn.Embedding` only holds this rank's block of rows.
        super().__init__(
            block_size,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        self.mode = mode

        if contiguous_chunks is not None:
            assert (
                sum(contiguous_chunks) == num_embeddings
            ), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to num_embeddings ({num_embeddings})"

        # Rows (dim 0) are the sharded dimension.
        split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)

        mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=split_config)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the local shard and combine the results across `pg`.

        Args:
            input_ids: Tensor of ids expressed in the unsharded vocabulary.

        Returns:
            The output of the mode-specific reduction primitive applied to the
            masked local lookup.

        Raises:
            ValueError: If `self.mode` is neither `ALL_REDUCE` nor `REDUCE_SCATTER`.
        """
        is_sharded = self.pg.size() > 1

        if is_sharded:
            # `False` if the id is owned by this rank, else `True`
            input_mask = torch.logical_or(self.min_id > input_ids, input_ids >= self.max_id)
            # Translate to `[0, self.max_id - self.min_id[`. The subtraction already
            # allocates a new tensor, so no explicit clone is needed to protect
            # the caller's `input_ids` from the in-place write below.
            masked_input = input_ids - self.min_id
            # default all out of bounds values to `0`
            masked_input[input_mask] = 0
        else:
            masked_input = input_ids

        out = super().forward(masked_input)

        if is_sharded:
            # Zero the rows looked up for ids this rank does not own, so that only
            # the owning rank contributes a non-zero embedding.
            out = out * (~input_mask[..., None])

        if self.mode is TensorParallelLinearMode.ALL_REDUCE:
            out = differentiable_all_reduce_sum(out, group=self.pg)
        elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
            out = differentiable_reduce_scatter_sum(out, group=self.pg)
        else:
            raise ValueError(f"Got unexpected mode: {self.mode}.")

        return out

    def extra_repr(self) -> str:
        """Extend the `nn.Embedding` repr with the TP rank and the unsharded row count."""
        return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_num_embeddings={self.original_num_embeddings}"