norm.py 278 Bytes
Newer Older
Ponku's avatar
Ponku committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch


def freeze_batch_norm(model: torch.nn.Module, *, bn_types=(torch.nn.BatchNorm2d,)) -> None:
    """Put every batch-norm submodule of ``model`` into evaluation mode.

    In eval mode, batch-norm layers normalize with their stored running
    statistics and stop updating them. The rest of the model's training flags
    are left untouched.

    Note: this only changes the module mode. It does NOT set
    ``requires_grad=False`` on the affine ``weight``/``bias`` parameters, so an
    optimizer will still update those. Also, a later ``model.train()`` call
    recursively switches these layers back to training mode. Call this
    function after ``model.train()``.

    Args:
        model: Module whose submodules (including itself) are scanned.
        bn_types: A class or tuple of classes to treat as batch norm. It is
            passed straight to ``isinstance``. The default ``(BatchNorm2d,)``
            preserves the original behavior. Pass e.g.
            ``(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm)`` to cover every
            batch-norm variant.
    """
    for module in model.modules():
        if isinstance(module, bn_types):
            module.eval()


def unfreeze_batch_norm(model: torch.nn.Module, *, bn_types=(torch.nn.BatchNorm2d,)) -> None:
    """Put every batch-norm submodule of ``model`` back into training mode.

    In training mode, batch-norm layers normalize with per-batch statistics and
    update their running statistics. The rest of the model's training flags
    are left untouched, so this reverses ``freeze_batch_norm`` without putting
    other eval-mode layers (e.g. dropout) back into training.

    Args:
        model: Module whose submodules (including itself) are scanned.
        bn_types: A class or tuple of classes to treat as batch norm. It is
            passed straight to ``isinstance``. The default ``(BatchNorm2d,)``
            preserves the original behavior. Use the same value that was given
            to ``freeze_batch_norm``.
    """
    for module in model.modules():
        if isinstance(module, bn_types):
            module.train()