# Meant to work with PyTorch's ZeroRedundancyOptimizer

from typing import Any, Dict, Optional
from pathlib import Path

import torch
from torch.optim.optimizer import Optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.core.optimizer import LightningOptimizer
try:  # pytorch_lightning <= 1.7
    from pytorch_lightning.utilities.types import _PATH
except ImportError:  # pytorch_lightning >= 1.8
    try:
        from lightning_lite.utilities.types import _PATH
    except ImportError:  # pytorch_lightning >= 1.9
        from lightning_fabric.utilities.types import _PATH


# Adapted from PyTorch's ZeroRedundancyOptimizer.state_dict method, but we only
# build the rank-local state dict, avoiding the cross-GPU synchronization that
# the consolidated path would require.
# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131
def get_zero_optimizer_state_dict_local(optimizer, global_rank):
    optimizer._check_overlap_initialized()

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)

    local_state_dict = optimizer.optim.state_dict()
    # State dict as seen by the plain Optimizer base class: it carries the
    # globally indexed param_groups, and its "state" entries are filled in
    # from the local shard below.
    state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()

    # Update the global optimizer state with local state information,
    # factoring in the translation from local to global indexing
    rank = global_rank
    # TODO: recursive copy to device
    local_param_groups = local_state_dict["param_groups"]
    global_param_groups = optimizer._partition_parameters()[rank]
    assert len(local_param_groups) == len(global_param_groups), \
        "Mismatch between number of local and global parameter groups"

    for local_param_group, global_param_group in zip(local_param_groups, global_param_groups):
        # `local_param_group` stores local indices, while
        # `global_param_group` stores the tensors directly
        local_param_indices = local_param_group["params"]
        global_params = global_param_group["params"]

        assert len(local_param_indices) == len(global_params), \
            "Mismatch between number of local and global parameters in parameter group"
        for local_param_index, global_param in zip(local_param_indices, global_params):
            # Update the global parameter state, if any
            if local_param_index in local_state_dict["state"]:
                global_param_index = optimizer._param_to_index[global_param]
                state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index]

    # Sort the parameters in the state
    state_dict["state"] = dict(sorted(state_dict["state"].items()))
    return state_dict
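
# A minimal sketch of how this helper differs from the stock consolidated
# path (illustrative only: it assumes the script is launched with `torchrun`
# so that a default process group already exists, and that CUDA is available):
#
#   import torch.distributed as dist
#
#   model = torch.nn.Linear(512, 512).cuda()
#   opt = ZeroRedundancyOptimizer(model.parameters(),
#                                 optimizer_class=torch.optim.Adam, lr=1e-3)
#   model(torch.randn(8, 512, device="cuda")).sum().backward()
#   opt.step()  # materialize this rank's shard of the Adam state
#
#   # Stock path: all-gather every shard onto rank 0, then dump everything.
#   opt.consolidate_state_dict(to=0)
#   full_sd = opt.state_dict() if dist.get_rank() == 0 else None
#
#   # Local path used here: no communication; the returned dict contains
#   # "state" entries only for the parameters this rank owns.
#   local_sd = get_zero_optimizer_state_dict_local(opt, dist.get_rank())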


class DDPStrategyZero1(DDPStrategy):
    """To use ZeroRedundancyOptimizer, we need to shard the optimizer states when
    saving/loading checkpoints.
    """

    strategy_name = "ddp_zero1"

    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if isinstance(optimizer, ZeroRedundancyOptimizer):
            return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)
        else:
            return optimizer.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.
        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        filepath = Path(filepath)
        filepath.mkdir(parents=True, exist_ok=True)
        # Optimizer states are rank-local shards, so every rank writes its own
        # file; the rest of the checkpoint is identical across ranks and is
        # written by rank 0 only.
        local_optimizer_states = checkpoint.pop('optimizer_states')
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
                                               storage_options=storage_options)
        self.checkpoint_io.save_checkpoint(local_optimizer_states,
                                           filepath / f'{self.global_rank:03d}_optim_states.pt',
                                           storage_options=storage_options)
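
    # On disk the checkpoint is therefore a directory rather than a single
    # file. Illustrative layout for a 2-GPU run (the directory name is
    # whatever the caller passed as `filepath`):
    #
    #   <filepath>/
    #       model_states.pt        # model + trainer state, written by rank 0 only
    #       000_optim_states.pt    # rank 0's optimizer-state shard
    #       001_optim_states.pt    # rank 1's optimizer-state shard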

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            # A plain single-file checkpoint (e.g. one saved without this strategy).
            return super().load_checkpoint(str(checkpoint_path))
        else:
            assert checkpoint_path.is_dir()
            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
            local_optimizer_states = self.checkpoint_io.load_checkpoint(checkpoint_path / f'{self.global_rank:03d}_optim_states.pt')
            global_states['optimizer_states'] = local_optimizer_states
            return global_states
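

# Minimal usage sketch (illustrative, not part of the strategy; `LitModel`,
# the data, and the hyperparameters are made up for the example). Each rank
# builds a ZeroRedundancyOptimizer so the Adam state is sharded across GPUs,
# and the strategy above checkpoints those shards one file per rank:
#
#   import pytorch_lightning as pl
#
#   class LitModel(pl.LightningModule):
#       def __init__(self):
#           super().__init__()
#           self.layer = torch.nn.Linear(32, 1)
#
#       def training_step(self, batch, batch_idx):
#           return self.layer(batch).sum()
#
#       def configure_optimizers(self):
#           # Each rank only materializes its shard of the optimizer state.
#           return ZeroRedundancyOptimizer(self.parameters(),
#                                          optimizer_class=torch.optim.Adam,
#                                          lr=1e-3)
#
#   data = torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)
#   trainer = pl.Trainer(accelerator="gpu", devices=8,
#                        strategy=DDPStrategyZero1())
#   trainer.fit(LitModel(), train_dataloaders=data)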