# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
A distributed data parallel class that works with OSS optimizer.

Adopted from LegacyDistributedDataParallel module from fairseq.
"""

from contextlib import contextmanager
import copy
from typing import Any, Dict, Generator, List, Optional, Type, cast

import torch
from torch import Tensor, nn
import torch.distributed as dist
from torch.nn import Parameter

from fairscale.optim import OSS


class ShardedDataParallel(nn.Module):
    """Implements distributed data parallel training with optimizer state sharding.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and optionally
    broadcasts the module buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        optimizer (~torch.optim.Optimizer): optimizer to be used for training
        optimizer_params (Dict): extra parameters for the optimizer
        world_size (int): number of parallel workers
        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
            the module at the beginning of the forward function. (default: ``True``)
        process_group (optional): the c10d process group to be used for
            distributed gradient reduction. If None, the default WORLD process group
            will be used.
        buffer_size (int, optional): number of elements to buffer before
            performing reduce (default: 256M). Used to reduce multiple small
            params to avoid communication overhead.
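
    A minimal usage sketch; ``MyModel``, ``rank``, ``world_size`` and ``batch``
    are illustrative placeholders, not part of this module::

        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        model = ShardedDataParallel(
            module=MyModel().to(rank),
            optimizer=torch.optim.SGD,
            optimizer_params={"lr": 1e-3},
            world_size=world_size,
            broadcast_buffers=True,
        )
        model(batch).sum().backward()
        model.reduce()  # gradients must be reduced explicitly, there is no c10d hook
        model.optimizer.step()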
    """

    def __init__(
        self,
        module: nn.Module,
        optimizer: Type[torch.optim.Optimizer],
        optimizer_params: Dict[str, Any],
        world_size: int,
        broadcast_buffers: bool = True,
        process_group: Any = None,
        buffer_size: int = 2 ** 28,
    ):
        super().__init__()

        self.module = module
        self.world_size = world_size
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.rank = dist.get_rank(self.process_group)
        self.broadcast_buffers = broadcast_buffers
        self.authoritative_rank = 0

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
        self.buffer: Optional[Tensor] = None

        # Flag used to make sure we only reduce gradients one time in the execution engine
        self.need_reduction = False

        # We can also forcibly accumulate grads locally and only do the
        # gradients-reduce at some later time
        self.accumulate_grads = False
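        # (Toggled by the `no_sync` context manager below; `reduce` is a no-op
        # while gradients are being accumulated locally.)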

        # Build the sharded optimizer
        self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

        # Sanity checks
        assert len(self.sharded_optimizer.param_to_rank) == len(
            list(self.module.parameters())
        ), "number of params do not match"
        for param in self.module.parameters():
            assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"

    def __getstate__(self) -> Dict:
        attrs = copy.copy(self.__dict__)
        return attrs

    @property
    def optimizer(self) -> torch.optim.Optimizer:
        return self.sharded_optimizer

    def train(self, mode: bool = True) -> "ShardedDataParallel":
        pre_mode = self.module.training
        self.module.train(mode)
        if self.module.training:
            # A pending gradient reduction is only legal if we were already in train mode
            assert not self.need_reduction or pre_mode, "incorrect state transition"
        else:
            # Gradients must have been reduced before switching to eval
            assert not self.need_reduction, "trying to enter eval with unreduced gradients"
        return self

    @contextmanager
    def no_sync(self) -> Generator:
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads
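
    # Sketch of the intended gradient accumulation pattern; ``model`` and
    # ``batches`` are illustrative placeholders:
    #
    #   with model.no_sync():
    #       for batch in batches[:-1]:
    #           model(batch).sum().backward()  # gradients accumulate locally
    #   model(batches[-1]).sum().backward()
    #   model.reduce()  # a single cross-rank reduction for the accumulated step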

    def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
        if self.module.training:
            if self.need_reduction:
                raise RuntimeError(
                    "ShardedDataParallel requires explicit reduction, must call ShardedDataParallel.reduce"
                )
            if not self.accumulate_grads:
                self.need_reduction = True
            if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
                self._sync_buffers()

        return self.module(*inputs, **kwargs)

    def reduce(self) -> None:
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
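
        Expected call order for one optimization step (a sketch, the ``loss``
        variable is illustrative)::

            loss.backward()        # local gradients on every rank
            self.reduce()          # grads are reduced onto the ranks owning the optimizer state
            self.optimizer.step()  # each rank only updates its own parameter shard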
        """
        assert self.module.training, "Cannot call reduce in eval"

        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
            """ Helper to reduce a list of params that should fit in the buffer.
            NOTE: All param gradients are assumed to exist"""
            assert self.buffer is not None

            # Fill in the packed IO buffer
            buffer: Tensor = cast(Tensor, self.buffer)
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
                    offset += sz
            else:
                # we only have a single grad to reduce
                buffer = params[0].grad.data  # type: ignore

            # Reduce
            buffer.div_(self.world_size)  # type: ignore
            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore

            # Copy reduced grads back into their original place, or free corresponding memory
            if params_rank == self.rank:
                offset = 0
                for p in params:
                    sz = p.numel()
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
                    offset += sz
            else:
                for p in params:
                    p.grad = None
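
        # Packing sketch for `reduce_grads`, with illustrative sizes: two gradients
        # of 3 and 5 elements owned by the same rank are flattened into buffer[0:3]
        # and buffer[3:8], reduced with a single dist.reduce call, and the owning
        # rank copies the slices back into each p.grad while the other ranks drop
        # their gradients to free memory.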

        def reduction_fn() -> None:
            # This function only needs to be called once
            if not self.need_reduction or self.accumulate_grads:
                return
            self.need_reduction = False

            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)  # type: ignore

            for params in self.sharded_optimizer.per_device_params:
                # Reduce the gradients in buckets
                offset = 0
                buffered_params: List[Parameter] = []
                param_rank: Optional[int] = None
                for param in params:
                    last_param_rank: Optional[int] = param_rank
                    param_rank = self.sharded_optimizer.param_to_rank[param]
                    if not param.requires_grad:
                        continue

                    if param.grad is None:
                        param.grad = torch.zeros_like(param)
                    if param.grad.requires_grad:
                        raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # reduce big params directly
                        assert param_rank is not None
                        reduce_grads([param], cast(int, param_rank))
                    else:
                        # smaller params are packed together from the same device
                        # and same rank.
                        if offset + sz > self.buffer.numel() or (
                            last_param_rank is not None and last_param_rank != param_rank
                        ):
                            assert last_param_rank is not None
                            reduce_grads(buffered_params, cast(int, last_param_rank))
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(cast(Parameter, param))
                        offset += sz

                if len(buffered_params) > 0:
                    assert param_rank is not None
                    reduce_grads(buffered_params, cast(int, param_rank))

        reduction_fn()

    def _sync_buffers(self) -> None:
        """
        Sync all the module buffers across ranks (e.g. BatchNorm running statistics),
        broadcasting from the authoritative rank.
        TODO: Could be worth bucketing?
        """
        _ = list(
            map(
                lambda x: x.wait(),
                map(
                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
                    self.module.buffers(),
                ),
            )
        )