# handle.py
import contextlib
import logging
import warnings

from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler, iter_params
from ._amp_state import _amp_state
from ..fp16_utils import FP16_Optimizer


# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
               optimizer,
               model=None,
               delay_unscale=False):
    """
    On context manager entrance, scale the loss in a way consistent with the current loss scale.
    Yield the loss

    On context manager exit (if ``delay_unscale=False``), unscale the gradients so that
    ``optimizer.step()`` can be called.

    .. note::
    If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
    can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
    any FP16 gradients are copied to FP32 master gradients before being unscaled.  ``optimizer.step()``
    will then apply the unscaled master gradients to the master params.
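
    Typical use (a minimal sketch, assuming ``loss``, ``model``, and ``optimizer`` have
    already been prepared by ``amp.initialize``)::

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()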
    """
    if not _amp_state.opt_properties.enabled:
        yield loss
        return

    if optimizer.loss_scaler is None:
        raise RuntimeError("optimizer passed to scale_loss does not have a loss_scaler.")

    loss_scale = optimizer.loss_scaler.loss_scale()

    if ((not _amp_state.opt_properties.master_weights)
        and (not optimizer.loss_scaler.dynamic)
        and loss_scale == 1.0):
        yield loss
        # Needing to drop the cache here as well is an ugly gotcha.
        # But for now I think it's necessary to short-circuit.
        # Probably ok to skip this if not delay_unscale
        if _amp_state.opt_properties.patch_torch_functions:
            _amp_state.handle._clear_cache()
        return

    yield loss*loss_scale

    # this isn't pretty but it unifies things.  Once I deprecate the old API entirely,
    # I will have freedom to clean this up.  Maybe instead of wrapping optimizers,
    # I can simply construct a set of attributes (e.g. master params) and assign them
    # directly to optimizer instances.
    if not delay_unscale:
        if isinstance(optimizer, FP16_Optimizer):
            optimizer.update_master_grads()
        else:
            optimizer.loss_scaler.clear_overflow_state()
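            # The same parameter iterator is passed as both the source and the
            # destination of the unscale below, so at this call site the gradients
            # are (as I read it) unscaled in place rather than copied to masters.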
            optimizer.loss_scaler.unscale(
                iter_params(optimizer.param_groups),
                iter_params(optimizer.param_groups),
                loss_scale)
            # For future fused optimizers that enable sync-free dynamic loss scaling,
            # should_skip will always be False.
            should_skip = optimizer.loss_scaler.update_scale()
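            # If update_scale() reported an overflow (inf/nan gradients), replace
            # optimizer.step with a one-shot no-op that logs a warning and then
            # restores the real step method, so the training loop can keep calling
            # optimizer.step() unconditionally.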
            if should_skip:
                optimizer_step = optimizer.step
                def skip_step():
                    logger = logging.getLogger('apex.amp')
                    logger.warning("Gradient overflow.  Skipping step, reducing " +
                                   "loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
                    optimizer.step = optimizer_step
                optimizer.step = skip_step

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()


# Free function version of AmpHandle.disable_casts, another step on the
# path to removing the concept of "AmpHandle"
@contextlib.contextmanager
def disable_casts():
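    """
    Temporarily disable the casts inserted by Amp's patched torch functions.

    Minimal usage sketch (``run_in_original_dtypes`` is a hypothetical user function)::

        with amp.disable_casts():
            output = run_in_original_dtypes(inputs)
    """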
    _amp_state.handle._is_active = False
    yield
    _amp_state.handle._is_active = True


class AmpHandle(object):
    def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler(loss_scale)
        self._is_active = True
        self._all_wrappers = []

    def is_active(self):
        return self._is_active

    @contextlib.contextmanager
    def _disable_casts(self):
        self._is_active = False
        yield
        self._is_active = True

    def wrap_optimizer(self, optimizer, num_loss=1):
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
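        """
        Context manager for the older handle-based API: yields ``loss * loss_scale``
        and unscales the gradients on exit.

        Usage sketch, assuming ``handle`` came from the deprecated handle-based
        initialization (e.g. ``amp.init()``)::

            with handle.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
        """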
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale

        self._default_scaler.clear_overflow_state()
        self._default_scaler.unscale(
            iter_params(optimizer.param_groups),
            iter_params(optimizer.param_groups),
            loss_scale)
        should_skip = self._default_scaler.update_scale()
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
                logger = logging.getLogger('apex.amp')
                logger.warning('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step

        self._clear_cache()

    def _clear_cache(self):
        self._cache.clear()

    # Experimental support for saving / restoring uncasted versions of functions
    def _save_func(self, mod, fn, func):
        self._all_wrappers.append((mod, fn, func))

    def _deactivate(self):
        for mod, fn, func in self._all_wrappers:
            utils.set_func(mod, fn, func)
        self._all_wrappers = []

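    # The cache holds casted copies of tensors (populated, as I understand it, by the
    # casting wrappers in utils) so repeated casts of the same weight can be reused
    # within an iteration; it is cleared after each step via _clear_cache above.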
    @property
    def has_cache(self):
        return self._enable_caching

    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]

    @property
    def verbose(self):
        return self._verbose

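
# Stand-in returned when Amp is disabled (presumably via the old handle-based
# init with enabled=False): same interface as AmpHandle, but performs no scaling,
# casting, or caching, so calling code does not need to branch.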
class NoOpHandle(object):
    def is_active(self):
        return False

    @contextlib.contextmanager
    def _disable_casts(self):
        yield

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False

    @property
    def verbose(self):
        return False

    def _clear_cache(self):
        pass

    def _deactivate(self):
        pass