Commit ddfb38ef authored by Jerry XU's avatar Jerry XU Committed by Kai Chen
Browse files

add pytorch 1.1.0 SyncBN support (#577)

* add pytorch 1.1.0 SyncBN support

* change BatchNorm2d to _BatchNorm and call freeze after train

* add freeze back to init function

* fixed indentation typo in adding freeze

* use SyncBN protected member func to set ddp_gpu_num

* Update README.md

update pytorch version to 1.1
parent c52cdd62
......@@ -3,7 +3,7 @@
## Introduction
The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1,
please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
mmdetection is an open source object detection toolbox based on PyTorch. It is
......
......@@ -2,6 +2,7 @@ import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
......@@ -437,7 +438,7 @@ class ResNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
......@@ -470,8 +471,9 @@ class ResNet(nn.Module):
def train(self, mode=True):
    """Set the module in training or evaluation mode.

    Args:
        mode (bool): Whether to set training mode (``True``) or evaluation
            mode (``False``). Default: ``True``.

    Besides the standard behaviour, re-applies stage freezing after every
    mode switch, and when ``norm_eval`` is enabled forces every batch-norm
    layer back to eval mode so running statistics stay frozen during
    training. Matching against ``_BatchNorm`` (rather than
    ``nn.BatchNorm2d``) covers ``SyncBatchNorm`` and all other BN variants.
    """
    super(ResNet, self).train(mode)
    # train(mode) resets all submodules, so frozen stages must be
    # re-frozen afterwards.
    self._freeze_stages()
    if mode and self.norm_eval:
        for m in self.modules():
            # trick: eval() has an effect on BatchNorm layers only
            if isinstance(m, _BatchNorm):
                m.eval()
......@@ -4,7 +4,7 @@ import torch.nn as nn
# Registry mapping a norm layer type name to its (abbreviation, module class).
# The abbreviation is used to build the layer's attribute name (e.g. 'bn1').
norm_cfg = {
    # format: layer_type: (abbreviation, module)
    'BN': ('bn', nn.BatchNorm2d),
    # native SyncBN is available since PyTorch 1.1
    'SyncBN': ('bn', nn.SyncBatchNorm),
    'GN': ('gn', nn.GroupNorm),
    # and potentially 'SN'
}
......@@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN':
layer._specify_ddp_gpu_num(1)
else:
assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment