import contextlib
import logging
import warnings

from .opt import OptimWrapper
from .scaler import LossScaler

class AmpHandle(object):
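    """Handle for active mixed-precision execution.

    Keeps a per-parameter cache (when caching is enabled) and a default
    LossScaler that scale_loss uses until an optimizer is wrapped via
    wrap_optimizer.
    """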
    def __init__(self, enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler()

    def is_active(self):
        return True

    def wrap_optimizer(self, optimizer, num_loss=1):
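        # From here on, loss scaling is managed by the OptimWrapper, so the
        # handle's default scaler is dropped (see the check in scale_loss).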
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
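        """Context manager that yields the loss multiplied by the current
        loss scale; on exit it unscales the gradients and skips the next
        optimizer step if an overflow was detected.
        """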
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
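        # Temporarily patch loss.backward so that calling backward() on the
        # unscaled loss inside this block emits a warning before running the
        # original backward.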
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper
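        # Hand the caller the loss multiplied by the current loss scale, so
        # backward() produces scaled gradients.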
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale
        loss.backward = loss_backward

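        # Unscale the gradients and update the loss scale; should_skip
        # indicates that the scaler saw an overflow.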
        should_skip = self._default_scaler.unscale_and_update(
            optimizer.param_groups, loss_scale)
        if should_skip:
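            # Overflow: replace optimizer.step with a no-op that logs the
            # skipped update and restores the real step for the next call.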
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step

        self._clear_cache()

    def _clear_cache(self):
        self._cache.clear()

    @property
    def has_cache(self):
        return self._enable_caching

    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]

    @property
    def verbose(self):
        return self._verbose

class NoOpHandle(object):
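    """Handle used when mixed precision is disabled: no caching, no loss
    scaling, and scale_loss yields the loss unchanged.
    """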
    def is_active(self):
        return False

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False

    @property
    def verbose(self):
        return False