# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of Petuum, Inc.  nor the names of its contributors may be
#    used to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import functools
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import numpy as np
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.optim import SGD, Optimizer

if TYPE_CHECKING:  # pragma: no cover
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class AdaScale(Optimizer):
    """
    Implements the AdaScale_ algorithm for scaling the learning rate for
    distributed and large batch size training. Can be used in combination with
    ``torch.nn.parallel.DistributedDataParallel`` and ``torch.optim.SGD``.

    .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf

    This class subclasses `Optimizer` so that `torch.optim.lr_scheduler` can
    work with it. In other words, AdaScale is intended to be a complete wrapper of a
    torch Optimizer.

    Note that AdaScale does *not* help increase per-GPU batch size.

    There are several ways to integrate AdaScale with your training loop.
    We show two examples below.

    Example 1: using PyTorch's `lr_scheduler` classes.

    .. code-block:: python

        optim = AdaScale(SGD(model.parameters(), lr=0.001))
        model = DistributedDataParallel(model)
        scheduler = LambdaLR(optim, lr_lambda=...)

        last_epoch = 0
        done = False
        step = 0
        while not done:
            for batch in dataset:
                optim.zero_grad()
                logits = model()
                loss = criterion(logits, ...)
                loss.backward()
                step += optim.gain()
                optim.step()
                epoch = step // len(dataset)
                if epoch > last_epoch:
                    scheduler.step()
                    last_epoch = epoch
                if epoch >= MAX_EPOCHS:
                    done = True

    Example 2: using a custom `update_lr()` function that updates the learning
    rate based on the current step count per epoch.

    .. code-block:: python

        optim = AdaScale(SGD(model.parameters(), lr=0.001))
        model = DistributedDataParallel(model)

        step = 0
        while step < max_steps:
            for batch in ...:
                optim.zero_grad()
                logits = model()
                loss = criterion()
                loss.backward()
                step += optim.gain()
                optim.step()
                update_lr(step)

    Args:
        optimizer (torch.optim.Optimizer):
            Optimizer to apply AdaScale to.
        world_size (int):
            Number of workers (world size) used for distributed training.
            If None, defaults to ``dist.get_world_size()``.
        scale (float):
            Scaling factor of the batch size relative to the baseline (scale equals 1), e.g. using a 10x
            larger batch size (summed across all ranks with gradient accumulation)
            means a scale of 10.
            If None, defaults to ``world_size * num_gradients_to_accumulate``.
        smoothing (float):
            Smoothing factor for moving average.
            If None, it defaults to ``max(1 - (world_size * num_gradients_to_accumulate)/1000, 0)``.
            Note that for very high scale training, a higher smoothing value might be needed,
            especially at the beginning of training. Therefore, if your scale is close to or larger
            than 1000, try experimenting with a smoothing value > 0 if the final accuracy is poor.
        num_gradients_to_accumulate (int):
            Number of passes that we accumulate gradients locally
            between each optimizer step. This can be changed during
            training as long as the train loop changes gradient accumulation
            accordingly.
            Defaults to 1, which does not accumulate gradients.
        debias_ewma (bool):
            (experimental) Use a debiased exponential moving average
            when smoothing the mu and sigma variables. If False, use
            the method in the paper's Appendix B.3.
            Default: True, which is what has been validated so far.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        world_size: Optional[int] = None,
        scale: Optional[float] = None,
        smoothing: Optional[float] = None,
        num_gradients_to_accumulate: int = 1,
        debias_ewma: bool = True,
    ):
        # Init hook_handles list, otherwise, a partial init'ed object may fail in ``__del__``.
        self._hook_handles: List[Any] = []

        # Init other fields.
        self._optimizer = optimizer
        self._local_grad_sqr: Optional[torch.Tensor] = None
        self._world_size: int = (
            world_size if world_size is not None else dist.get_world_size() if dist.is_initialized() else 1
        )
        self._num_backward_calls = 0
        self._last_final_backward_call = 0
        self._num_grads_to_accum = num_gradients_to_accumulate
        self._debias_ewma = debias_ewma

        # Proxy the param_groups so that `torch.optim.lr_scheduler` can work.
        self.param_groups = self._optimizer.param_groups

        self.set_num_gradients_to_accumulate(num_gradients_to_accumulate, update_smoothing=True)

        # The previous function call sets smoothing to its default value.
        # Override that here if smoothing was passed as an argument.
        if smoothing is not None:
            self._smoothing = smoothing

        if self._world_size * self._num_grads_to_accum <= 1:
            # gain will be NaN since we will be dividing by zero in paper's B.3 where (S-1) == 0.
            raise RuntimeError("AdaScale does not support a single worker without grad accumulation.")

        # Per-param-group sqr & var states (sigma^2 & mu^2 in the paper).
        self._optimizer.state.setdefault(
            "adascale",
            {
                "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
                "grad_var_avg": np.zeros(len(optimizer.param_groups)),
            },
        )

        self._scale = 1.0  # Assign to inform mypy about the typing of this variable.
        self.set_scale(self._world_size * self._num_grads_to_accum if scale is None else scale)

        # Safer to register hooks after all init actions are done.
        self._hook()

    def _hook(self) -> None:
        """Internal function to register the gradient hooks.

        Note, don't assume every parameter will generate a gradient (i.e. trigger the hook)
        in every backward pass, which is the reason we have the ``find_unused_parameters``
        flag in the DDP class in ``torch.nn.parallel``.
        """
        assert self._hook_handles == [], "Must run unhook first"
        for idx, param_group in enumerate(self._optimizer.param_groups):
            for param in param_group["params"]:
                h = param.register_hook(functools.partial(self._backward_hook, idx))
                self._hook_handles.append(h)

    def __del__(self) -> None:
        """Unhook in case caller forgets to call unhook.

        This, however, may not "work" since there could be a circular reference
        between the hook objects and this object. In that case, neither will
        get GC'ed. Call unhook explicitly if you really want to delete
        AdaScale from memory.
        """
        self.unhook()

    def unhook(self) -> None:
        """Unregister hook handles.

        This is public because the caller may need to call it to ensure all GPU
        memory is released. Otherwise, the hooks may prevent parameters from being
        released from the GPU memory pool.

        Internally, we use this to support ``add_param_group()`` API.
        """
        for h in self._hook_handles:
            h.remove()
        self._hook_handles = []

    @property
    def _state(self) -> Dict[str, np.ndarray]:
        """
        Return the states of AdaScale.
        """
        return self._optimizer.state["adascale"]

    @property
    def scale(self) -> float:
        """
        The scaling factor of the current batch size, relative to the baseline
        batch size, which could be a DDP training. For example, if the
        baseline batch size is 32 on 2 GPUs and a scaled-up batch size
        of 80 on 4 GPUs is used, then the scaling factor is 80 * 4 / 32 / 2 = 5.

        This API is exposed mainly for logging purposes. Note that this is
        different from ``self.gain()``.

        Returns:
            (float):
                The current scaling factor.
        """
        return self._scale

    @property
    def smoothing(self) -> float:
        """
        The smoothing constant used in the exponentially-weighted moving
        average that tracks the gradient norm mean and variance within
        AdaScale.

        This API is exposed because the value is computed internally and the
        caller may want to obtain and log it.

        Returns:
            (float):
                The current smoothing value.
        """
        return self._smoothing

    def set_scale(self, scale: float, update_estimate: bool = True) -> None:
        """
        Set the scaling factor of the current batch size. It is up to the
        application to invoke this function to make sure that AdaScale's
        scaling factor matches the actual batch size used during training.

        Args:
            scale (float):
                New scaling factor to be applied to AdaScale.
            update_estimate (bool):
                Whether to update the scale-dependent estimate of gradient
                variance; this is highly recommended. (default: True)
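
        For example, if the per-GPU batch size is later doubled while the
        number of GPUs stays the same, a minimal sketch of keeping AdaScale
        in sync (``world_size`` here is an assumed caller-side variable)
        could be:

        .. code-block:: python

            # The effective batch size is now 2x the original scale baseline.
            optim.set_scale(2.0 * world_size)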
        """
        assert self._local_grad_sqr is None, "Don't change scale in backward phase"
        assert scale >= 1, "Scale must be at least 1"
        if update_estimate and hasattr(self, "_scale"):
            assert self._scale >= 1, "bug: old scale isn't valid"
            # Rescale grad_var_avg to account for the change in scale
            if self._debias_ewma and "grad_var_avg_biased" in self._state:
                self._state["grad_var_avg_biased"] *= self._scale / scale
            elif "grad_var_avg_total" in self._state:  # _debias_ewma==False
                self._state["grad_var_avg_total"] *= self._scale / scale
            self._state["grad_var_avg"] *= self._scale / scale
        self._scale = scale

    def _grad_sqr_avg(self, pg_idx: Optional[int] = None) -> float:
        """
        Current estimate of the squared l2-norm of the true gradient
        (sigma squared in the AdaScale paper).

        Args:
            pg_idx (Optional[int]):
                Optional index for a parameter group.

        Returns:
            (float):
                Estimate of squared l2-norm.
        """
        if pg_idx is not None:
            return self._state["grad_sqr_avg"][pg_idx]
        else:
            return float(np.sum(self._state["grad_sqr_avg"]))

    def _grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
        """
        Current estimate of the trace of the covariance of the true gradient
        (mu squared in the AdaScale paper).

        Args:
            pg_idx (Optional[int]):
                Optional index for a parameter group.

        Returns:
            (float):
                Estimate of trace of the covariance.
        """
        if pg_idx is not None:
            return self._state["grad_var_avg"][pg_idx]
        else:
            return float(np.sum(self._state["grad_var_avg"]))

    def gain(self, pg_idx: Optional[int] = None) -> float:
        """
        Current estimate of the AdaScale gain ratio (r_t in the paper).
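
        The gain is computed from the current gradient variance and squared
        gradient norm estimates as ``(var + sqr) / (var / scale + sqr)``.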

        Args:
            pg_idx (int):
                Optional index of a parameter group.
                Default None: returns "averaged" gain for all groups.

        Returns:
            (float):
                Estimate of gain ratio.
        """
        var = self._grad_var_avg(pg_idx)
        sqr = self._grad_sqr_avg(pg_idx)
        gain = (var + sqr) / (var / self.scale + sqr)
        return gain

    def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
        if self._debias_ewma:
            # This function computes and stores the moving average of a vector
            # using a smoothing factor.
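            # The "_unbias" accumulator equals 1 - factor**t after t updates,
            # so dividing the biased average by it corrects the bias toward
            # zero that comes from initializing the average at zero (the same
            # debiasing idea used in Adam's moment estimates).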
            biased = self._state.get(name + "_biased", np.zeros(value.shape[0]))
            unbias = self._state.get(name + "_unbias", np.zeros(value.shape[0]))
            biased = factor * biased + (1.0 - factor) * value
            unbias = factor * unbias + (1.0 - factor)
            self._state[name + "_biased"] = biased
            self._state[name + "_unbias"] = unbias
            self._state[name] = biased / unbias
        else:
            # Moving average procedure described in Appendix B.3
            # For iterations t < 1 / (1 - smoothing) define grad_var_avg
            # and grad_sqr_avg as mean of the past samples. After that
            # start using running average.
            #
            # Note: we only keep a single _count for all parameter groups.
            #       Ideally, it should be a vector and in case a PG is added
            #       after some iterations are done. But, then the if condition
            #       below will need to be a np.where. I leave this corner
            #       case to a future exercise.
            count = self._state.get(name + "_count", np.zeros(1))
            count[0] += 1
            self._state[name + "_count"] = count
            if count < 1 / (1 - self._smoothing):
                total = self._state.get(name + "_total", None)
                if total is None:
                    total = value
                else:
                    total += value
                self._state[name + "_total"] = total
                self._state[name] = total / count
            else:
                self._state[name] = factor * self._state[name] + (1.0 - factor) * value

    def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
        # This method should be invoked once for each parameter during the
        # backward pass, before gradients are synchronized across workers.

        # Store the local gradient square sums in a vector.
        # This vector is also used for error checking. Whenever it is not None,
        # it means that we are in backward pass.
        if self._local_grad_sqr is None:
            self._local_grad_sqr = torch.zeros(
                len(self._optimizer.param_groups),
                device=grad.device,
                requires_grad=False,
            )
        self._local_grad_sqr[pg_idx] += grad.pow(2).sum()

        # Now, ensure we queue a callback at the end of the callback queue.
        # This will fire after all gradient callbacks are done (especially
        # those queued by DDP).
        self._final_callback_queued = False
        Variable._execution_engine.queue_callback(self._queue_callback)

    def _queue_callback(self) -> None:
        # This method should be invoked after the entire backward pass. We want
        # to make sure self._final_callback is invoked once, only after all
        # gradients have been synchronized between each worker. However, the
        # synchronization code in DistributedDataParallel is also done in a
        # callback, which might not yet be executed. Therefore, we enqueue
        # self._final_callback from this method, which should ensure it is
        # invoked after the gradient synchronization callback.
        if self._final_callback_queued:
            return
        self._final_callback_queued = True
        Variable._execution_engine.queue_callback(self._final_callback)

    def _final_callback(self) -> None:
        # This method should be invoked once for each backward pass, after
        # gradients have been synchronized between each worker, unless we
        # are in gradient accumulation mode, where grads are not all_reduced
        # between the GPUs.
        self._final_callback_queued = False
        assert isinstance(self._local_grad_sqr, torch.Tensor)

        # Keep track of number of backward calls for gradient accumulation.
        # TODO (min): this may not work with activation checkpointing when
        #             multiple backward calls happen in a big backward.
        self._num_backward_calls += 1

        # TODO (min, mike): We need to have a way to check that training loop & DDP
        #                   is doing the right thing where the gradient is reduced
        #                   in this backward pass.
        #                   Longer term, we may compute the gain and then inform
        #                   the training loop when it is a good time to step().
        assert (
            self._num_backward_calls - self._last_final_backward_call
        ) <= self._num_grads_to_accum, (
            f"bug: {self._num_backward_calls} - {self._last_final_backward_call} should <= {self._num_grads_to_accum}"
        )
        if (self._num_backward_calls - self._last_final_backward_call) % self._num_grads_to_accum != 0:
            assert self._local_grad_sqr is not None, "We should still be in backward phase"
            return

        # Since self._local_grad_sqr is FP32, sum shouldn't overflow.
        # This vector has length of # of param_groups, so it is small, but we
        # use async to hide the all_reduce latency, esp when # of nodes is large.
        work = None
        if self._world_size > 1:
            work = dist.all_reduce(self._local_grad_sqr, async_op=True)  # SUM

        # Compute the sums of squares for reduced gradients.
        # Divide by _num_grads_to_accum since the gradients are accumulated.
        total_grad_sqr = np.array(
            [sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
        )
        # Divide by (_num_grads_to_accum ** 2) to account for gradient
        # accumulation.
        if self._num_grads_to_accum > 1:
            # np array doesn't support /=.
            total_grad_sqr = total_grad_sqr / (self._num_grads_to_accum**2)

        # Wait for all_reduce to be done and move it to cpu & np.
        if work:
            work.wait()
        local_grad_sqr = self._local_grad_sqr.cpu().numpy()

        # See appendix B.3 of the paper.
        # Modified to handle cases where scale != world_size
        #
        # local_grad_sqr is \sum_{i=1}^{c N} \norm{g_t_i}^2
        # where N is world size and c is num_grads_to_accum
        # total_grad_sqr is \norm{\bar{g}_t}^2
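        # The expression below is an unbiased sample-variance estimate over
        # the cN local gradients, rescaled by S / cN so that it reflects the
        # gradient variance at the scale-1 batch size; grad_sqr then subtracts
        # the remaining variance contribution from the squared-mean estimate.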
        S = self._scale
        cN = self._world_size * self._num_grads_to_accum
        grad_var = local_grad_sqr * (S / cN) / (cN - 1) - total_grad_sqr * S / (cN - 1)
        grad_sqr = total_grad_sqr - grad_var / S
        grad_var = np.maximum(grad_var, 1e-6)
        grad_sqr = np.maximum(grad_sqr, 0.0)
        self._update_avg("grad_sqr_avg", grad_sqr, self.smoothing)
        self._update_avg("grad_var_avg", grad_var, self.smoothing)
        self._last_final_backward_call = self._num_backward_calls
        # Indicating backward is done.
        self._local_grad_sqr = None

    def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
        """
        Run one optimizer step using AdaScale. Essentially just invokes
        ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.

        .. note::

            It is possible that this function becomes a performance
            bottleneck if you have frequent updates. To avoid that,
            making bigger steps and reducing update frequency is generally
            better for performance.

        Args:
            args (Any):
                Positional arguments passed to ``optimizer.step``.
            kwargs (Any):
                Keyword arguments passed to ``optimizer.step``.

        Returns:
            (Tensor):
                The loss tensor if a closure is used to re-evaluate the model.
        """
        assert self._local_grad_sqr is None, "Don't step without finishing backward phase"
        # Save the original LRs and set the new, gain-scaled LRs.
        original_lr = []
        for idx, param_group in enumerate(self._optimizer.param_groups):
            original_lr.append(param_group["lr"])
            param_group["lr"] = self.gain(pg_idx=idx) * param_group["lr"]

        # Step it.
        res = self._optimizer.step(*args, **kwargs)

        # Restore the original LR.
        for lr, param_group in zip(original_lr, self._optimizer.param_groups):
            param_group["lr"] = lr

        return res

    def add_param_group(self, pg: Dict) -> None:
        """Support adding parameter groups

        We need to re-size some of the state and re-register the backward hooks.
        """
        assert self._local_grad_sqr is None, "Can't add parameter group during backward"
        self._optimizer.add_param_group(pg)
        # Update the hooks.
        self.unhook()
        self._hook()
        # Extend the states.
        for name in self._state.keys():
            assert name.startswith("grad_sqr_avg") or name.startswith("grad_var_avg"), name
            if name.endswith("_count"):
                # This is the "_count" variable, should be a 1D int.
                assert self._state[name].shape == (1,), self._state[name].shape
                continue
            # must be a np array, extend it with the right value and check the shape.
            val = 1 if name == "grad_sqr_avg" else 0
            self._state[name] = np.append(self._state[name], val)  # type: ignore
            assert self._state[name].shape == (len(self._optimizer.param_groups),)

    def zero_grad(self) -> None:
        """Proxy function to optimizer, because some training loops need this."""
        assert self._local_grad_sqr is None, "Don't zero_grad in backward"
        return self._optimizer.zero_grad()

    def state_dict(self) -> Dict:
        """Proxy function to optimizer, checkpointing needs this.

        .. note::

            Do NOT checkpoint in the middle of gradient accumulation since
            associated AdaScale internal states are not saved in the checkpoint.
        """
        assert self._local_grad_sqr is None, "Don't checkpoint in backward"
        return self._optimizer.state_dict()

    def load_state_dict(self, data: Dict) -> None:
        """Proxy function to optimizer, checkpointing needs this.

        .. note::

            Do NOT checkpoint in the middle of gradient accumulation since
            associated AdaScale internal states are not saved in the checkpoint.
        """
        assert self._local_grad_sqr is None, "Don't load checkpoint in backward"
        return self._optimizer.load_state_dict(data)

    def set_num_gradients_to_accumulate(
        self,
        num_gradients_to_accumulate: int,
        update_smoothing: bool = True,
    ) -> None:
        """Set the number of gradients to accumulate to a new value.

        This is experimental. This could be called while training so that
        we can gradually increase the steps between updates. Almost always,
        `set_scale` needs to be called to update the scale as well.

        TODO (min): need a way to determine how much to increase the step size?

        TODO (min): having both `set_scale` and `set_num_gradients_to_accumulate`
        is hard to use and easy to get wrong. I think it is better
        to specify a `base_scale`. But more discussion is needed here.

        Args:
            num_gradients_to_accumulate (int):
                Number of gradients to accumulate (calls to backward) between
                each optimizer step
            update_smoothing (bool):
                Whether to update smoothing factor or not. Default: True.
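
        A minimal sketch of growing the accumulation count mid-training
        (``world_size`` and ``new_accum`` are assumed caller-side variables):

        .. code-block:: python

            optim.set_num_gradients_to_accumulate(new_accum)
            # The effective batch size grew as well, so update the scale.
            optim.set_scale(world_size * new_accum)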
        """
        assert self._local_grad_sqr is None, "Don't change num_grad_to_accum in backward"
        assert num_gradients_to_accumulate >= 1, f"Invalid value {num_gradients_to_accumulate}"
        self._num_grads_to_accum = num_gradients_to_accumulate
        if update_smoothing:
            # Set smoothing based on effective world_size rather than scale here,
            # since world_size determines the number of samples being averaged over
            # at every update.
            #
            # When effective world size is large enough, smoothing is probably
            # not needed, so the smoothing factor is 0.
            self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped optimizer."""
        try:
            return super().__getattr__(name)  # defer to Optimizer logic
        except AttributeError:
            return getattr(self._optimizer, name)  # fallback to wrapped optim


class AdaScaleWrapper(AdaScale):
    """
    A thin wrapper for AdaScale so that the constructor resembles a
    standard optimizer. This allows it to work with other Optimizer
    Wrappers, like `OSS`.

    .. warning::
        OSS(AdaScaleWrapper) (i.e. OSS wrapping AdaScale) results in each
        rank's AdaScale operating on a different set of parameters. The ranks
        will get different gain values, and it is unclear how to adjust the
        effective step size in that case. We have not validated the
        effectiveness or benefit of this configuration.

        OTOH, AdaScale(OSS) (i.e. AdaScale wrapping OSS) is recommended
        and is numerically identical to AdaScale without OSS. Since
        AdaScale doesn't incur per-parameter state, the memory benefit
        of OSS is still the same.

    Args:
        params (list of tensors):
            parameters to be optimized
        optim_cls (class subtyping torch.optim.Optimizer):
            an optimizer class to be wrapped.
        additional_optim_args (argument dict):
            keyword arguments to the `optim_cls` class above.

    The remaining arguments are in sync with the `AdaScale` class above.
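
    A minimal usage sketch (assuming ``model`` and a DDP setup are provided
    by the caller):

    .. code-block:: python

        optim = AdaScaleWrapper(model.parameters(), optim_cls=SGD, lr=0.1)
        # From here on, `optim` is used like in the AdaScale examples above.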
    """

    def __init__(
        self,
        params: _params_t,
        world_size: Optional[int] = None,
        scale: Optional[float] = None,
        smoothing: Optional[float] = None,
        num_gradients_to_accumulate: int = 1,
        debias_ewma: bool = True,
        optim_cls: Type[Optimizer] = SGD,
        **additional_optim_args: Any,
    ):
        optim_obj = optim_cls(params, **additional_optim_args)
        super().__init__(optim_obj, world_size, scale, smoothing, num_gradients_to_accumulate, debias_ewma)