Commit ddfb38ef authored by Jerry XU's avatar Jerry XU Committed by Kai Chen
Browse files

add pytorch 1.1.0 SyncBN support (#577)

* add pytorch 1.1.0 SyncBN support

* change BatchNorm2d to _BatchNorm and call freeze after train

* add freeze back to init function

* fixed indentation typo in adding freeze

* use SyncBN protect member func to set ddp_gpu_num

* Update README.md

update pytorch version to 1.1
parent c52cdd62
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
## Introduction ## Introduction
The master branch works with **PyTorch 1.0** or higher. If you would like to use PyTorch 0.4.1, The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1,
please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch. please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
mmdetection is an open source object detection toolbox based on PyTorch. It is mmdetection is an open source object detection toolbox based on PyTorch. It is
......
...@@ -2,6 +2,7 @@ import logging ...@@ -2,6 +2,7 @@ import logging
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import constant_init, kaiming_init from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
...@@ -437,7 +438,7 @@ class ResNet(nn.Module): ...@@ -437,7 +438,7 @@ class ResNet(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
kaiming_init(m) kaiming_init(m)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1) constant_init(m, 1)
if self.dcn is not None: if self.dcn is not None:
...@@ -470,8 +471,9 @@ class ResNet(nn.Module): ...@@ -470,8 +471,9 @@ class ResNet(nn.Module):
def train(self, mode=True): def train(self, mode=True):
super(ResNet, self).train(mode) super(ResNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval: if mode and self.norm_eval:
for m in self.modules(): for m in self.modules():
# trick: eval have effect on BatchNorm only # trick: eval have effect on BatchNorm only
if isinstance(m, nn.BatchNorm2d): if isinstance(m, _BatchNorm):
m.eval() m.eval()
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
norm_cfg = { norm_cfg = {
# format: layer_type: (abbreviation, module) # format: layer_type: (abbreviation, module)
'BN': ('bn', nn.BatchNorm2d), 'BN': ('bn', nn.BatchNorm2d),
'SyncBN': ('bn', None), 'SyncBN': ('bn', nn.SyncBatchNorm),
'GN': ('gn', nn.GroupNorm), 'GN': ('gn', nn.GroupNorm),
# and potentially 'SN' # and potentially 'SN'
} }
...@@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''): ...@@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
cfg_.setdefault('eps', 1e-5) cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN': if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_) layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN':
layer._specify_ddp_gpu_num(1)
else: else:
assert 'num_groups' in cfg_ assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_) layer = norm_layer(num_channels=num_features, **cfg_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment