import itertools
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch.optim
from functorch.dim import tree_map
from torch import nn
from tqdm import tqdm

from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import human_format, log_rank, warn_once
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter

logger = logging.get_logger(__name__)


class ZeroDistributedOptimizer(InheritFromOtherOptimizer):
    """Optimizer that handles partitioning of optimizer's states across DP ranks. See ZeRO Stage 1 in the paper https://arxiv.org/abs/1910.02054v3 for more details."""

    def __init__(
        self,
        named_params_or_groups: Iterable[Union[Tuple[str, NanotronParameter], Dict[str, Any]]],
        optimizer_builder: Callable[[Iterable[Dict[str, Any]]], BaseOptimizer],
        dp_pg: ProcessGroup,
    ):
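        # `named_params_or_groups` may be either an iterable of (name, param) pairs, or a list of
        # param-group dicts of the form {"named_params": [(name, param), ...], **other_group_kwargs}
        # (e.g. a per-group "lr"); both forms are normalized below into `self.zero_named_param_groups`.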
        named_params_or_groups = list(named_params_or_groups)
        if len(named_params_or_groups) == 0 or isinstance(named_params_or_groups[0], dict):
            # case where named_params_or_groups is Iterable[Dict[str, Any]]
            for d in named_params_or_groups:
                assert (
                    "named_params" in d
                ), f"param_groups must contain a 'named_params' key, got a dict with keys {d.keys()}"

            # keep only named_params_or_groups that require grads
            named_params_or_groups = [
                {
                    "named_params": [
                        (name, param) for name, param in named_param_group["named_params"] if param.requires_grad
                    ],
                    **{k: v for k, v in named_param_group.items() if k != "named_params"},
                }
                for named_param_group in named_params_or_groups
            ]

            self.zero_named_param_groups = named_params_or_groups
        else:
            # case where named_params_or_groups is Iterable[Tuple[str, NanotronParameter]]
            # keep only named_params_or_groups that require grads
            named_params_or_groups = [(name, param) for name, param in named_params_or_groups if param.requires_grad]
            self.zero_named_param_groups = [{"named_params": named_params_or_groups}]

        self.dp_pg = dp_pg  # DP process group

        # Partition the model's params across DP ranks.
        # `self.param_name_to_dp_rank_offsets` maps each parameter name to {dp_rank: (start_offset, end_offset)},
        # i.e. the slice of the flattened parameter that each DP rank is responsible for updating.
        # NOTE: some param_groups may have no params in the current rank; we still keep them in self.optimizer.param_groups
        self.param_name_to_dp_rank_offsets = self._partition_parameters()

        current_dp_rank = dist.get_rank(self.dp_pg)
        param_groups_in_rank = [
            {
                "named_params": [
                    (
                        name,
                        get_sliced_tensor(
                            param=param,
                            start_offset=self.param_name_to_dp_rank_offsets[name][current_dp_rank][0],
                            end_offset=self.param_name_to_dp_rank_offsets[name][current_dp_rank][1],
                        ),
                    )
                    for name, param in param_group["named_params"]
                    if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
                ],
                **{k: v for k, v in param_group.items() if k != "named_params"},
            }
            for param_group in self.zero_named_param_groups
        ]

        # Initialize this rank's optimizer, which is responsible for updating this rank's parameter slices.
        # NOTE: In case of ZeRO, `self.id_to_name` only stores the names of the parameters that this DP rank's optimizer updates.
        # NOTE: In case of ZeRO, `self.optimizer` only receives the parameters that this DP rank's optimizer updates,
        # which means `self.optimizer.param_groups` is only a subset of `self.param_groups`.
        optimizer = optimizer_builder(param_groups_in_rank)
        super().__init__(optimizer=optimizer, id_to_name=optimizer.id_to_name)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step (parameter update)."""
        # TODO: @nouamanetazi: handle syncing param groups attrs (e.g. if we update lr)

        loss = super().step(closure=closure)

        # calculate param_size (model) + param_size (grads) + 2*param_size/DP_if_zero1 (optim_states)
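        # Illustrative arithmetic (hypothetical sizes): 1e9 fp32 params with DP=8 gives roughly
        # 2 * 4GB (params + grads) + 2 * 4GB / 8 (sharded optim states) = 9GB expected allocated.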
        expected_allocated = sum(
            param.numel() * param.element_size() * 2 + param.numel() * param.element_size() * 2 / self.dp_pg.size()
            for named_param_group in self.zero_named_param_groups
            for _, param in named_param_group["named_params"]
        )

        log_rank(
            f"[After optim states allocation] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MB "
            f"(Expected 2*param_size + 2*param_size/DP_if_zero1={expected_allocated / 1024**2:.2f}MB). "
            f"Peak reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MB",
            logger=logger,
            level=logging.DEBUG,
            group=self.dp_pg,
            rank=0,
        )

        # All gather updated params
        self._all_gather_params()
        return loss

    def zero_grad(self):
        """Copied from `torch.optim.optimizer.zero_grad` with the only change of using `self.param_groups` instead of `self.optimizer.param_groups`
        because we want to zero out the gradients of all model params (not just the params in the current rank)"""
        super().zero_grad()

        # TODO @thomasw21: This is a call to torch internal API, we need to fix this
        foreach = False  # self.optimizer.defaults.get("foreach", False)

        # TODO @thomasw21: This is a call to torch internal API, we need to fix this
        # if not hasattr(self.optimizer, "_zero_grad_profile_name"):
        #     self.optimizer._hook_for_profile()

        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))

        # TODO @thomasw21: This is a call to torch internal API, we need to fix this
        # with torch.autograd.profiler.record_function(self.optimizer._zero_grad_profile_name):

        # zero out the gradients of all model params (not just the params in the current rank)
        for named_param_group in self.zero_named_param_groups:
            for _, p in named_param_group["named_params"]:
                if p.grad is not None:
                    p.grad = None
        if foreach:
            for _, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._foreach_zero_(grads)

    def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
        named_params = [
            (name, param)
            for named_param_group in self.zero_named_param_groups
            for name, param in named_param_group["named_params"]
            if param.requires_grad
        ]

        # Maps each model param to the DP rank(s) responsible for updating it.

        # We assume that parameters can be sharded across DP, i.e. a single parameter can be "split" across DP ranks.
        # This breaks optimizers that need the full parameter, such as Adafactor.
        # `param_name_to_dp_rank_offsets[name]` is a `Dict[int, Tuple[int, int]]`: keys are dp_rank, and each
        # `Tuple[int, int]` gives the (start, end) offsets of the slice of the flattened param owned by that DP rank.
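        # Illustrative example (hypothetical numbers): a parameter with numel=10 and dp_size=4 is split as
        # {0: (0, 3), 1: (3, 6), 2: (6, 8), 3: (8, 10)} (the last `remainder` ranks get one element less).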
        param_name_to_dp_rank_offsets = {}

        # NOTE: save the original shapes before flattening the params
        # so that later on, we can reshape the params to their original shapes
        # for topology-agnostic optimizer states loading
        self._orig_param_shapes = {}
        for name, param in named_params:
            self._orig_param_shapes[name] = param.shape

        for name, param in named_params:
            # We assume the parameter is contiguous so that we have an easy way of sharding it.
            assert param.is_contiguous(), f"Parameter {name} is not contiguous"

            numel = param.numel()
            padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1
            sizes = np.full(shape=(self.dp_pg.size()), fill_value=padded_numel_per_dp)
            remainder = padded_numel_per_dp * self.dp_pg.size() - numel
            # The last `remainder` ranks get one element less
            if remainder > 0:
                # NOTE: the guard is needed because `sizes[-0:]` selects the entire array instead of nothing
                sizes[-remainder:] -= 1
            end_offsets = np.cumsum(sizes)
            assert len(end_offsets) == self.dp_pg.size()
            assert end_offsets[-1] == numel, f"Somehow {end_offsets[-1]} != {numel}"
            # Derive the start offsets from the end offsets
            start_offsets = np.concatenate([[0], end_offsets[:-1]])

            param_name_to_dp_rank_offsets[name] = {
                dp_rank: (start_offsets[dp_rank], end_offsets[dp_rank])
                for dp_rank in range(self.dp_pg.size())
                if start_offsets[dp_rank] < end_offsets[dp_rank]  # Only if the slice is not empty.
            }

        log_rank("[ZeRO sharding] Size of optimizer params per rank:", logger=logger, level=logging.INFO, rank=0)
        all_numel = sum(
            param_name_to_dp_rank_offsets[name][dp_rank][1] - param_name_to_dp_rank_offsets[name][dp_rank][0]
            for name, param in named_params
            for dp_rank in range(self.dp_pg.size())
            if dp_rank in param_name_to_dp_rank_offsets[name]
        )
        for dp_rank in range(self.dp_pg.size()):
            acc_numel = sum(
                value[dp_rank][1] - value[dp_rank][0]
                for value in param_name_to_dp_rank_offsets.values()
                if dp_rank in value
            )
            log_rank(
                f"[ZeRO sharding] DP Rank {dp_rank} has {human_format(acc_numel)} out of {human_format(all_numel)} ({0 if all_numel == 0 else acc_numel / all_numel * 100:.2f}%) params' optimizer states",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )

        return param_name_to_dp_rank_offsets

    def _all_gather_params(self):
        """All gather updated params"""
        all_named_tensors_to_gather = [
            (name, param.view(-1))
            for named_param_groups in self.zero_named_param_groups
            for name, param in named_param_groups["named_params"]
        ]

        if len(all_named_tensors_to_gather) == 0:
            # No need to broadcast if there's nothing
            return

        if self.dp_pg.size() == 1:
            # They should already be updated
            return

        current_dp_rank = dist.get_rank(self.dp_pg)
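        # Each rank contributes its own updated slice (`input_tensor_list`) and receives every rank's slice
        # directly into views of the full flattened params (`output_tensor_lists`), so the gather updates
        # the model parameters in place.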
        dist.all_gather_coalesced(
            output_tensor_lists=[
                [
                    tensor[slice(*self.param_name_to_dp_rank_offsets[name][dp_rank])]
                    if dp_rank in self.param_name_to_dp_rank_offsets[name]
                    else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
                    for dp_rank in range(self.dp_pg.size())
                ]
                for name, tensor in all_named_tensors_to_gather
            ],
            input_tensor_list=[
                tensor[slice(*self.param_name_to_dp_rank_offsets[name][current_dp_rank])]
                if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
                else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
                for name, tensor in all_named_tensors_to_gather
            ],
            group=self.dp_pg,
        )


# Helpers


class SlicedFlatTensor(torch.Tensor):
    """Subclass of `torch.Tensor` that unable to define `grad` getter on a slice of a flattened tensor."""

    # Based on torch/testing/_internal/logging_tensor.py
    # https://github.com/pytorch/pytorch/issues/102337#issuecomment-1579673041
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def get_sliced_flat_tensor(data, start_offset, end_offset):
        with torch.no_grad():
            return data.view(-1)[start_offset:end_offset]

    @staticmethod
    def __new__(cls, data, start_offset, end_offset):
        sliced_tensor = cls.get_sliced_flat_tensor(data=data, start_offset=start_offset, end_offset=end_offset)

        result = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            sliced_tensor.size(),
            strides=sliced_tensor.stride(),
            storage_offset=sliced_tensor.storage_offset(),
            # TODO: clone storage aliasing
            dtype=sliced_tensor.dtype,
            layout=sliced_tensor.layout,
            device=sliced_tensor.device,
            requires_grad=sliced_tensor.requires_grad,
        )
        return result

    def __init__(self, data, start_offset, end_offset):
        super().__init__()
        # TODO @thomasw21: Make it so that you can never update this value
        self.sliced_flat_tensor = self.get_sliced_flat_tensor(
            data=data, start_offset=start_offset, end_offset=end_offset
        )
        self.orig_data = data
        self.start_offset = start_offset
        self.end_offset = end_offset

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.sliced_flat_tensor if isinstance(e, cls) else e

        def never_wrap(e):
            # Never re-wrap
            return e

        return tree_map(never_wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

    def _get_grad(self):
        if self.orig_data.grad is None:
            return None
        with torch.no_grad():
            return self.orig_data.grad.view(-1)[self.start_offset : self.end_offset]

    def _set_grad(self, grad):
        if grad is not None:
            orig_grad = self._get_grad()
            if orig_grad is None:
                raise NotImplementedError(
                    "Trying to set gradient on a sliced tensor when the original tensor hasn't allocated the buffer for the gradient"
                )
            orig_grad.copy_(grad)
            return
        # TODO @thomasw21: This is unfortunately necessary since we might pass a `SlicedFlatTensor` to the optimizer.
        warn_once(
            logger=logger,
            msg="You're setting a `SlicedTensor` gradient to None. We're going to assume you meant to set the original tensor gradient to None.",
            rank=0,
        )
        self.orig_data.grad = None

    def _del_grad(self):
        raise NotImplementedError

    # TODO @thomasw21: Figure out why this function doesn't get inherited. https://github.com/pytorch/pytorch/issues/102337#issuecomment-1634363356
    def data_ptr(self):
        return self.sliced_flat_tensor.data_ptr()

    grad = property(_get_grad, _set_grad, _del_grad)


def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int):
    # This allows us to create a leaf tensor, despite sharing the underlying storage
    result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset)
    return result
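
# Illustrative behaviour (hypothetical shapes): for a contiguous parameter `p` with numel 10,
# `get_sliced_tensor(p, start_offset=2, end_offset=5)` returns a leaf tensor viewing `p.view(-1)[2:5]`;
# its `grad` property reads from, and writes into, the matching region of `p.grad`.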


def find_optim_index_from_param_name(
    param_name: str,
    # NOTE: keys are (pp_rank, dp_rank, tp_rank) for ZeRO-1 checkpoints, or (pp_rank, tp_rank) otherwise
    ckp_sharded_optim_states: Union[
        Dict[Tuple[int, int, int], Dict[str, Any]], Dict[Tuple[int, int], Dict[str, Any]]
    ],
    is_zero1: bool,
    pp_rank=0,
) -> Optional[int]:
    param_name = param_name.replace("module.", "")
    # NOTE: all shards have the same optim state names,
    # so we take the first shard (along every dimension except, optionally, the pp dimension)
    if is_zero1 is True:
        # NOTE: (pp_rank, dp_rank, tp_rank)
        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
    else:
        # NOTE: (pp_rank, tp_rank)
        OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]

    return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)


def extract_parallel_ranks_from_shard_path(
    shard_path: Path, is_zero1: bool
) -> Union[Tuple[int, int, int], Tuple[int, int]]:
    """Extract parallel ranks from shard path

    For example, if the shard path is:
    + For ZeRO-1: /path/to/optimizer_pp-0-of-1_dp-0-of-2_tp-0-of-1.pt
    then the function will return (0, 0, 0) (pp_rank, dp_rank, tp_rank)

    For ZeRO-0: /path/to/optimizer_pp-0-of-1_tp-0-of-1.pt
    then the function will return (0, 0) (pp_rank, tp_rank)
    """
    if is_zero1 is True:
        # TODO(xrsrke): use the same pattern as weight checkpoints
        # in weight checkpoints, we do pp-rank-.... but here we only do pp-...
        # TODO(xrsrke): don't hardcode this
        pattern = r"optimizer_pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+\.pt"
        match = re.search(pattern, str(shard_path))
        pp_rank, dp_rank, tp_rank = match.groups()
        return int(pp_rank), int(dp_rank), int(tp_rank)
    else:
        # NOTE: this is zero0 checkpoint
        pattern = r"pp-(\d+)-of-\d+_tp-(\d+)-of-\d+"
        match = re.search(pattern, str(shard_path))
        pp_rank, tp_rank = match.groups()
        return int(pp_rank), int(tp_rank)


def merge_dp_shard_in_zero1_optimizer(
    model: nn.Module,
    optimizer_config,
    shard_paths: List[Path],
    parallel_context: ParallelContext,
    map_location: Optional[str] = None,
) -> Dict[Tuple[int, int], Dict[str, torch.Tensor]]:  # (pp_rank, tp_rank): param_name -> optim_state
    assert (
        optimizer_config["configs"]["param_name_to_dp_rank_offsets"] is not None
    ), "param_name_to_dp_rank_offsets is required"

    checkpoint_pp_size = optimizer_config["parallelism"]["pp_size"]
    checkpoint_tp_size = optimizer_config["parallelism"]["tp_size"]

    ckp_sharded_optim_states = {}
    for shard_path in shard_paths:
        pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
        ckp_sharded_optim_states[(pp_rank, dp_rank, tp_rank)] = torch.load(shard_path, map_location=map_location)

    param_name_to_dp_rank_offsets = optimizer_config["configs"]["param_name_to_dp_rank_offsets"]
    optimizer_state_names = ckp_sharded_optim_states[(0, 0, 0)]["state"][0].keys()

    def get_numel_of_unsharded_dp_param(param_name):
        dp_offsets = param_name_to_dp_rank_offsets[param_name]
        return max(int(value) for values in dp_offsets.values() for value in values)

    def assign_shard_to_buffer(buffer, offset, value):
        offset_start, offset_end = map(int, offset)
        buffer[offset_start:offset_end] = value
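    # Illustrative example (hypothetical offsets): with dp offsets {"0": (0, 3), "1": (3, 6)}, the two dp
    # shards of a given optimizer state are copied back-to-back into an unsharded flat buffer of numel 6.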

    param_names = sorted(model.state_dict().keys(), key=lambda x: x)
    ckp_merged_dp_shards_optim_states = {}
    for pp_rank, tp_rank in tqdm(
        list(itertools.product(range(int(checkpoint_pp_size)), range(int(checkpoint_tp_size)))),
        disable=dist.get_rank(parallel_context.world_pg) != 0,
        desc="Merging ZeRO-1's shards across data parallel dimension",
    ):
        # NOTE: filter only the shards that correspond to the current pp_rank and tp_rank
        filtered_ckp_sharded_optim_states = {}
        for (pp, dp, tp), ckp_optim_state in ckp_sharded_optim_states.items():
            if pp == pp_rank and tp == tp_rank:
                filtered_ckp_sharded_optim_states[dp] = ckp_optim_state

        # NOTE: now, for each parameter, merge all of its shards across the data parallel dimension
        merged_dp_shards_optim_states = {}

        merged_dp_shards_optim_states["state"] = {}

        for param_name in param_names:
            unshard_dp_size = get_numel_of_unsharded_dp_param(param_name)
            optim_state_index = find_optim_index_from_param_name(
                param_name=param_name,
                ckp_sharded_optim_states=ckp_sharded_optim_states,
                is_zero1=True,
            )
            merged_dp_shards_optim_states["state"][optim_state_index] = {}
            for state_name in optimizer_state_names:
                unsharded_dp_buffer = torch.zeros(unshard_dp_size, device="cuda")
                # NOTE: now merge all the params across data parallel dimension
                for dp_rank, ckp_optim_state in filtered_ckp_sharded_optim_states.items():
                    # NOTE: extract the optimizer state of the current parameter
                    ckp_optim_state = ckp_optim_state["state"][optim_state_index]
                    ckp_offset = param_name_to_dp_rank_offsets[param_name][str(dp_rank)]
                    assign_shard_to_buffer(unsharded_dp_buffer, ckp_offset, ckp_optim_state[state_name])

                # NOTE: in optimizer states, the "state" use an index to represent the parameter
                # not the parameter name
                merged_dp_shards_optim_states["state"][optim_state_index][state_name] = unsharded_dp_buffer
                # NOTE: each dp shard has the same step
                merged_dp_shards_optim_states["state"][optim_state_index]["step"] = ckp_optim_state["step"]

        ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)] = merged_dp_shards_optim_states
        # NOTE: all dp shards of the same (pp, tp) shard have identical "names" and "param_groups",
        # so we copy them from dp_rank 0 (the 0 in (pp_rank, 0, tp_rank) is the dp_rank)
        ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)]["names"] = ckp_sharded_optim_states[
            (pp_rank, 0, tp_rank)
        ]["names"]
        ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)]["param_groups"] = ckp_sharded_optim_states[
            (pp_rank, 0, tp_rank)
        ]["param_groups"]

    assert len(ckp_merged_dp_shards_optim_states) == int(checkpoint_pp_size) * int(
        checkpoint_tp_size
    ), f"Expect {int(checkpoint_pp_size) * int(checkpoint_tp_size)} merged dp shards, got {len(ckp_merged_dp_shards_optim_states)}"

    # NOTE: sanity check: make sure each merged checkpoint
    # has the same dict keys as the original checkpoint
    for (pp_rank, tp_rank), ckp_optim_state in ckp_merged_dp_shards_optim_states.items():
        # NOTE: we remove the gradient_accumulator key from sanity check
        # because we don't merge gradient_accumulator states
        missing_keys = set(ckp_sharded_optim_states[(pp_rank, 0, tp_rank)].keys()) - set(ckp_optim_state.keys())
        assert (
            len(missing_keys - {"gradient_accumulator"}) == 0
        ), "Expected the merged dp shards to have the same keys as the original dp shards, but merged dp shard misses: {}".format(
            missing_keys
        )

    return ckp_merged_dp_shards_optim_states