Unverified Commit 797ef57e authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

[Fix] Fix SyncBN build in PyTorch 1.9 (#1138)

* [Fix] Fix SyncBN build in PyTorch 1.9

* fixed parrots SyncBN
parent db097bd1
......@@ -106,7 +106,7 @@ def build_norm_layer(cfg, num_features, postfix=''):
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN':
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
else:
assert 'num_groups' in cfg_
......
......@@ -82,10 +82,6 @@ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_):
def _specify_ddp_gpu_num(self, gpu_size):
if TORCH_VERSION != 'parrots':
super()._specify_ddp_gpu_num(gpu_size)
def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots':
if input.dim() < 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment