# lars.py (4.39 KB) -- last committed by maming.
from keras import backend as K
from keras.optimizers import Optimizer


class LARS(Optimizer):
    """Layer-wise Adaptive Rate Scaling (LARS) with momentum.

    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    For every parameter tensor ``w`` with gradient ``g`` a separate (local)
    learning rate is computed::

        local_lr = lr * eeta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)

    falling back to ``lr`` when ``||w|| * ||g||`` is not positive. The update
    is momentum SGD on the weight-decayed gradient::

        step = local_lr * (g + weight_decay * w)
        v    = momentum * v - step
        w    = w + v                        (nesterov=False)
        w    = w + momentum * v - step      (nesterov=True)

    The paper reports that this allows scaling the batch size up to 32K
    without significant performance degradation, and recommends using the
    optimizer in conjunction with:
        - Gradual learning rate warm-up
        - Linear learning rate scaling
        - Poly rule learning rate decay

    Args:
        lr: A `Tensor` or floating point value. The base learning rate.
        momentum: A non-negative floating point value. Momentum
            hyperparameter.
        weight_decay: A non-negative floating point value. Weight decay
            hyperparameter; it is used both in the trust-ratio denominator
            and as an L2 term added to the gradient.
        eeta: LARS coefficient as used in the paper. Default set to the LARS
            coefficient from the paper. (eeta / weight_decay) determines the
            highest scaling factor in LARS.
        epsilon: Optional epsilon added to the trust-ratio denominator, for
            models that have very small gradients. Default set to 0.0.
        nesterov: when set to True, nesterov momentum will be enabled.
        **kwargs: forwarded to the Keras `Optimizer` base class.

    Raises:
        ValueError: if `momentum` or `weight_decay` is negative.
    """

    def __init__(self,
                 lr,
                 momentum=0.9,
                 weight_decay=0.0001,
                 eeta=0.001,
                 epsilon=0.0,
                 nesterov=False,
                 **kwargs):

        if momentum < 0.0:
            raise ValueError("momentum should be positive: %s" % momentum)
        if weight_decay < 0.0:
            raise ValueError("weight_decay is not positive: %s" % weight_decay)
        super(LARS, self).__init__(**kwargs)
        # Hyperparameters are backend variables so they can be changed during
        # training (e.g. by a learning-rate schedule) and read in get_config.
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.momentum = K.variable(momentum, name='momentum')
            self.weight_decay = K.variable(weight_decay, name='weight_decay')
            self.eeta = K.variable(eeta, name='eeta')
        self.epsilon = epsilon
        self.nesterov = nesterov

    def _local_lr(self, param, grad):
        """Return the LARS-scaled learning rate (a scalar tensor) for one tensor.

        Args:
            param: the parameter variable being optimized.
            grad: the gradient of the loss with respect to `param`.
        """
        w_norm = K.sqrt(K.sum(K.square(param)))
        g_norm = K.sqrt(K.sum(K.square(grad)))
        trust_ratio = self.eeta * w_norm / (
            g_norm + self.weight_decay * w_norm + self.epsilon)
        # If either norm is zero (e.g. a zero-initialised bias, or a tensor
        # that received no gradient), the ratio is 0 or undefined; use the
        # base learning rate unchanged in that case.
        ratio = K.switch(K.greater(w_norm * g_norm, 0.),
                         trust_ratio,
                         K.ones_like(trust_ratio))
        return self.lr * ratio

    def get_updates(self, loss, params):
        """Build the update ops for one optimization step.

        Each tensor in `params` gets its own LARS learning rate (see
        `_local_lr`) and its own momentum slot.

        Args:
            loss: the loss tensor to minimize.
            params: list of variables to optimize.

        Returns:
            The list of update ops (also stored in `self.updates`).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # One momentum buffer per parameter; exposed via self.weights so the
        # optimizer state can be saved and restored with get/set_weights.
        moments = [K.zeros(K.int_shape(param), dtype=K.dtype(param))
                   for param in params]
        self.weights = [self.iterations] + moments

        for param, grad, moment in zip(params, grads, moments):
            local_lr = self._local_lr(param, grad)
            # Weight decay enters the step itself (g + wd * w), as in
            # Algorithm 1 of the paper, not only the trust-ratio denominator.
            step = local_lr * (grad + self.weight_decay * param)
            velocity = self.momentum * moment - step
            self.updates.append(K.update(moment, velocity))

            if self.nesterov:
                new_param = param + self.momentum * velocity - step
            else:
                new_param = param + velocity

            # Apply constraints.
            if getattr(param, 'constraint', None) is not None:
                new_param = param.constraint(new_param)

            self.updates.append(K.update(param, new_param))
        return self.updates

    def get_config(self):
        """Return the constructor arguments merged with the base config."""
        config = {'lr': float(K.get_value(self.lr)),
                  'momentum': float(K.get_value(self.momentum)),
                  'weight_decay': float(K.get_value(self.weight_decay)),
                  'epsilon': self.epsilon,
                  'eeta': float(K.get_value(self.eeta)),
                  'nesterov': self.nesterov}
        base_config = super(LARS, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))