handle.py
import contextlib
import logging
import warnings

from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler

class AmpHandle(object):
    def __init__(self, enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler()
        self._is_active = True
        self._all_wrappers = []

    def is_active(self):
        return self._is_active

    @contextlib.contextmanager
    def _disable_casts(self):
        self._is_active = False
        yield
        self._is_active = True
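
    # Illustrative sketch (not part of the runtime path): a function patched by amp
    # can use _disable_casts to run without automatic casting. The names
    # `orig_linear` and `wrapped_linear` below are hypothetical:
    #
    #     def wrapped_linear(*args, **kwargs):
    #         with handle._disable_casts():
    #             return orig_linear(*args, **kwargs)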

    def wrap_optimizer(self, optimizer, num_loss=1):
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)
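
    # A hedged usage sketch: once an optimizer is wrapped, loss scaling is driven
    # through the wrapper, i.e. `optimizer.scale_loss(loss)` rather than
    # `handle.scale_loss(loss, optimizer)`. Model and optimizer names are placeholders:
    #
    #     optimizer = handle.wrap_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    #     with optimizer.scale_loss(loss) as scaled_loss:
    #         scaled_loss.backward()
    #     optimizer.step()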

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
        # Patch loss.backward for the duration of the block so that calling
        # .backward() on the unscaled loss (instead of the yielded scaled loss)
        # emits a warning.
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale
        loss.backward = loss_backward

        should_skip = self._default_scaler.unscale_and_update(
            optimizer.param_groups, loss_scale)
        if should_skip:
            # Monkey-patch optimizer.step with a one-shot no-op that logs the
            # overflow and then restores the real step, so this update is skipped
            # without touching the weights.
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step

        self._clear_cache()
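
    # A minimal training-step sketch for the default path (no wrapped optimizer),
    # assuming `handle` was obtained from amp.init() and that `model`, `criterion`,
    # and `optimizer` already exist:
    #
    #     optimizer.zero_grad()
    #     loss = criterion(model(inputs), targets)
    #     with handle.scale_loss(loss, optimizer) as scaled_loss:
    #         scaled_loss.backward()
    #     optimizer.step()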

    def _clear_cache(self):
        self._cache.clear()

    # Experimental support for saving / restoring uncasted versions of functions
    def _save_func(self, mod, fn, func):
        self._all_wrappers.append((mod, fn, func))

    def _deactivate(self):
        for mod, fn, func in self._all_wrappers:
            utils.set_func(mod, fn, func)
        self._all_wrappers = []
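
    # Hypothetical illustration of the bookkeeping above: before monkey-patching a
    # function, amp records the original so _deactivate can restore it. The target
    # function and the `casting_mm` wrapper below are examples only:
    #
    #     handle._save_func(torch, 'mm', torch.mm)    # remember the original
    #     utils.set_func(torch, 'mm', casting_mm)     # install the casting wrapper
    #     ...
    #     handle._deactivate()                        # restore torch.mm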

    @property
    def has_cache(self):
        return self._enable_caching

    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]
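
    # When caching is enabled, casted copies of tensors are keyed by their source
    # in `cache`, and remove_cache evicts a single stale entry. A hedged sketch of
    # how a casting wrapper might consult the cache (the surrounding code is assumed):
    #
    #     if handle.has_cache and param in handle.cache:
    #         return handle.cache[param]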

    @property
    def verbose(self):
        return self._verbose
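
# NoOpHandle below mirrors AmpHandle's interface while doing no casting or loss
# scaling, so code written against the handle API runs unchanged when mixed
# precision is disabled. A hedged sketch:
#
#     handle = NoOpHandle()
#     with handle.scale_loss(loss, optimizer) as scaled_loss:
#         scaled_loss.backward()    # the loss is yielded unmodified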

class NoOpHandle(object):
    def is_active(self):
        return False

    @contextlib.contextmanager
    def _disable_casts(self):
        yield

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False

    @property
    def verbose(self):
        return False