# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
from itertools import chain
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type

import torch
import torch.distributed as dist
from torch.optim import SGD, Optimizer

from .utils import broadcast_object, recursive_copy_to_device

if TYPE_CHECKING:  # pragma: no cover
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class OSS(Optimizer):
    """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
    optimizer and shards its state as described by ZeRO_.
    ::
        opt = OSS(params, optim=torch.optim.Adam, lr=0.01)

    .. _ZeRO: https://arxiv.org/abs/1910.02054

    We use a greedy algorithm to pack a number of parameters
    at each rank. Each parameter belongs to a single rank and
    is not divided among ranks.

    After each rank has completed its parameter update, it broadcasts
    the new version of its parameters to all other ranks to synchronize
    the parameters for the next forward/backward computation.
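
    A rough sketch of a training loop (assumes ``torch.distributed`` is already
    initialized and that ``model``, ``loss_fn`` and ``loader`` are defined on
    this rank, with gradients kept identical across ranks, e.g. via all-reduce,
    before calling ``step()``)
    ::
        optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
        for batch in loader:
            optimizer.zero_grad()
            loss_fn(model(batch)).backward()
            # each rank updates its own parameter shard, then broadcasts the new values
            optimizer.step()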

    Args:
        params (list of tensors):
            parameters to be optimized
    Keyword Args:
        optim (torch.optim.Optimizer):
            optimizer to shard (default: SGD)
        group (group):
            torch.distributed group (default: group.WORLD)
    """

    optim: Optimizer
    in_super_constructor: bool

    def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
        # Hold all the model params in the root .param_groups
        self.in_super_constructor = True
        super().__init__(params, defaults)
        self.in_super_constructor = False

        # Build the wrapped optimizer, responsible for a shard of the params
        self.group = group
        self.rank = dist.get_rank(group)
        param_groups = self.partition_parameters()
        self.optim = optim(param_groups[self.rank], **defaults)

        # Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []

        # Current device is set by the parameters allocated to this rank
        self._device = self.partition_parameters()[self.rank][0]["params"][0].device

    def partition_parameters(self) -> List[List[dict]]:
        """Partitions parameters across distributed ranks.

        Returns a list of param_groups (each of which is a list of dicts) where
        element i of the list contains the param_groups assigned to rank i.
        All the ranks are needed for the broadcast inside step().
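
        For example (hypothetical sizes), with a world size of 2 and parameters
        of 4, 3, 2 and 2 elements arriving in that order, the greedy packing
        assigns the 4-element and last 2-element parameters to rank 0 (6 elements
        total) and the 3-element and first 2-element parameters to rank 1 (5 elements).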
        """
        world_size = dist.get_world_size(self.group)
        param_groups: List[List] = [list() for _ in range(world_size)]
        sizes = [0] * world_size
        for param_group in self.param_groups:
            param_lists: List[List] = [list() for _ in range(world_size)]
            for param in param_group["params"]:
                # Add this param to the rank with the smallest accumulated size.
                rank = sizes.index(min(sizes))
                param_lists[rank].append(param)
                sizes[rank] += param.numel()
            for rank, params in enumerate(param_lists):
                param_group_rank = copy.copy(param_group)
                param_group_rank["params"] = params
                param_groups[rank].append(param_group_rank)
        return param_groups

    # NOTE(msb) We add **kwargs in order to support Optimizer subclasses that accept extra kwargs.
    # For example, the apex library contains fused optimizers whose step() supports extra kwargs.
    def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
        # Sync the lr in case it has been updated by an LRScheduler.
        self._sync_lr()

        # Run the optimizer step on this shard only
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore

        # Broadcast the updated parameters from the rank that owns each shard to all other ranks
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    dist.broadcast(tensor=param, src=rank, group=self.group)
        return loss

    def local_state_dict(self) -> dict:
        """ Gets this rank's state_dict. """
        return self.optim.state_dict()

    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """ Update the consolidated state_dict list, one per rank.

        This needs to be called on all replicas. """

        # Sync the lr in case it has been updated by an LRScheduler.
        self._sync_lr()

        if self.rank == recipient_rank:
            # Pull the sharded state from all the other replicas
            # Store all the states in order, rank by rank
            logging.debug("Pulling the sharded SGD state from all replicas")
            self._all_states = self._collect_sharded_states()
        else:
            # Acknowledge broadcasts, and send this rank's shard when needed
            self._broadcast_state_dict()

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the last known global optimizer state, which consists of a list of the shards.

        NOTE: This is limited to the replica which was responsible for the consolidation.
        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
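
        A minimal checkpointing sketch (``path`` is a placeholder chosen by the
        caller)
        ::
            # every rank takes part in the consolidation
            optimizer.consolidate_state_dict(recipient_rank=0)
            if torch.distributed.get_rank() == 0:
                torch.save(optimizer.state_dict(), path)

            # later, on every rank, restore the full sharded state
            optimizer.load_state_dict(torch.load(path))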
        """

        assert (
            len(self._all_states) > 0
        ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"

        return {"state": self._all_states}

    def load_local_state_dict(self, state_dict: dict) -> None:
        """ Loads this rank's state_dict. """

        self.optim.load_state_dict(state_dict)

        # Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
        # Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
        groups = self.optim.param_groups
        saved_groups = state_dict["param_groups"]
        id_map = {
            old_id: p
            for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
        }
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)

        # Restore the global param_groups (the params themselves are already correct)
        for global_group, local_group in zip(self.param_groups, groups):
            for k, v in local_group.items():
                if k != "params":
                    global_group[k] = v

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """ Restore the global parameter groups as well as the shard """
        # Dispatch this rank's state dictionary to the wrapped shard optimizer
        self.load_local_state_dict(state_dict["state"][self.rank])

    def add_param_group(self, param_group: dict) -> None:
        super().add_param_group(param_group)
        if not self.in_super_constructor:
            param_groups = self.partition_parameters()[self.rank]
            if len(param_groups) == len(self.optim.param_groups) + 1:
                self.optim.add_param_group(param_groups[-1])

    def _sync_lr(self) -> None:
        """Sync learning rate (needed to support LRScheduler)."""
        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
            local_group["lr"] = global_group["lr"]

    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """
        Collect all the state shards, in CPU memory.
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
        all_states: List[Dict[str, Any]] = []

        for rank in range(dist.get_world_size(group=self.group)):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
                )

                # Sync with other replicas
                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
            else:
                # Fetch the optim state from the other replicas
                logging.debug("Receiving state from rank %s ", rank)
                replica_state = broadcast_object(
                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
                )

                all_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )

                logging.debug("State from rank %s received", rank)

        return all_states

    def _broadcast_state_dict(self) -> None:
        """
        Broadcast this rank's state shard, discard others
        """
        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

        for rank in range(dist.get_world_size(group=self.group)):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded SGD state to the reference replica from rank %s", rank,
                )
                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
            else:
                # Discard this tensor/rank, broadcast necessary for syncing
                logging.debug("Discarding broadcast from rank %s", rank)
                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)