NoBiasDecay.py 1.19 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn

def noBiasDecay(model, lr, weight_decay):
    '''
    no bias decay : only apply weight decay to the weights in convolution and fully-connected layers
    In paper [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)
    Ref: https://github.com/weiaicunzai/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks/blob/master/utils.py
    '''
    decay, bias_no_decay, weight_no_decay = [], [], []
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            decay.append(m.weight)
            if m.bias is not None:
                bias_no_decay.append(m.bias)    
        else: 
            if hasattr(m, 'weight'):
                weight_no_decay.append(m.weight)
            if hasattr(m, 'bias'):
                bias_no_decay.append(m.bias)
        
    assert len(list(model.parameters())) == len(decay) + len(bias_no_decay) + len(weight_no_decay)
    
    # bias using 2*lr
    return [{'params': bias_no_decay, 'lr': 2*lr, 'weight_decay': 0.0}, {'params': weight_no_decay, 'lr': lr, 'weight_decay': 0.0}, {'params': decay, 'lr': lr, 'weight_decay': weight_decay}]