Unverified Commit c30e91db authored by Cao Yuhang, committed by GitHub

share torch version (#343)
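Summary: this change introduces a shared TORCH_VERSION constant in mmcv/utils/env.py and replaces direct reads of torch.__version__ across the codebase, so every module (including the parrots backend, whose version string is the literal 'parrots') branches on one consistent name. A minimal sketch of the pattern, assuming the module layout shown in the hunks below:

    # mmcv/utils/env.py -- single source of truth for the version string
    import torch

    TORCH_VERSION = torch.__version__

    # consumers then import the shared name instead of torch.__version__:
    #     from mmcv.utils import TORCH_VERSION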

parent b87e774f
mmcv/parallel/distributed.py
@@ -3,6 +3,7 @@ import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
@@ -47,7 +48,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
@@ -79,6 +80,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if torch.__version__ > '1.2':
+            if TORCH_VERSION > '1.2':
                 self.require_forward_param_sync = False
         return output
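Reviewer note: TORCH_VERSION > '1.2' keeps the pre-existing lexicographic string comparison, which misorders multi-digit components ('1.10.0' > '1.2' is False even though 1.10.0 is the newer release). A sketch of a more robust check using the packaging library (an assumption; it is not used anywhere in this diff):

    from packaging.version import parse  # assumed available; not part of this commit

    # lexicographic string comparison misorders multi-digit parts:
    assert not ('1.10.0' > '1.2')           # '1' < '2' at the third character
    assert parse('1.10.0') > parse('1.2')   # semantic comparison is correct

    def version_at_least(version_str, minimum):
        # hypothetical helper: treat the parrots backend as new enough,
        # since its version string is not a parseable release number
        if version_str == 'parrots':
            return True
        return parse(version_str) >= parse(minimum)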
mmcv/parallel/distributed_deprecated.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
+from mmcv.utils import TORCH_VERSION
 from .scatter_gather import scatter_kwargs
@@ -37,7 +38,7 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(module_states,
                                            self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            if torch.__version__ < '1.0':
+            if TORCH_VERSION < '1.0':
                 buffers = [b.data for b in self.module._all_buffers()]
             else:
                 buffers = [b.data for b in self.module.buffers()]
mmcv/runner/dist_utils.py
@@ -7,6 +7,8 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from mmcv.utils import TORCH_VERSION
+
 def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
@@ -49,7 +51,7 @@ def _init_dist_slurm(backend, port=29500):
 def get_dist_info():
-    if torch.__version__ < '1.0':
+    if TORCH_VERSION < '1.0':
         initialized = dist._initialized
     else:
         if dist.is_available():
mmcv/runner/hooks/logger/tensorboard.py
 # Copyright (c) Open-MMLab. All rights reserved.
 import os.path as osp
-import torch
+from mmcv.utils import TORCH_VERSION
 from ...dist_utils import master_only
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -22,7 +21,7 @@ class TensorboardLoggerHook(LoggerHook):
     @master_only
     def before_run(self, runner):
-        if torch.__version__ < '1.1' or torch.__version__ == 'parrots':
+        if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
             try:
                 from tensorboardX import SummaryWriter
             except ImportError:
mmcv/utils/__init__.py
 # Copyright (c) Open-MMLab. All rights reserved.
 from .config import Config, ConfigDict, DictAction
+from .env import TORCH_VERSION
 from .logging import get_logger, print_log
 from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
                    is_str, is_tuple_of, iter_cast, list_cast,
@@ -29,5 +30,6 @@ __all__ = [
     'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
     '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
     '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
-    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader'
+    'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
+    'TORCH_VERSION'
 ]
mmcv/utils/env.py (new file)
+# This file holds environment constants shared by other files.
+import torch
+
+TORCH_VERSION = torch.__version__
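A side benefit of funneling every read through a single module attribute: version-dependent branches can be exercised in tests by patching one name. A hypothetical pytest sketch (test name and assertions are illustrative, not from the repo):

    import mmcv.utils.env as env

    def test_parrots_branch(monkeypatch):
        # patch the constant on its defining module; note that modules which
        # already did `from mmcv.utils import TORCH_VERSION` hold their own
        # snapshot, so only call-time reads of env.TORCH_VERSION see this
        monkeypatch.setattr(env, 'TORCH_VERSION', 'parrots')
        assert env.TORCH_VERSION == 'parrots'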
mmcv/utils/parrots_wrapper.py
@@ -2,9 +2,11 @@ from functools import partial
 import torch
+from .env import TORCH_VERSION
+
 def _get_cuda_home():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import CUDA_HOME
     else:
         from torch.utils.cpp_extension import CUDA_HOME
@@ -12,7 +14,7 @@ def _get_cuda_home():
 def get_build_config():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.config import get_build_info
         return get_build_info()
     else:
@@ -20,7 +22,7 @@ def get_build_config():
 def _get_conv():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
     else:
         from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
@@ -28,7 +30,7 @@ def _get_conv():
 def _get_dataloader():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from torch.utils.data import DataLoader, PoolDataLoader
     else:
         from torch.utils.data import DataLoader
@@ -37,7 +39,7 @@ def _get_dataloader():
 def _get_extension():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.utils.build_extension import BuildExtension, Extension
         CppExtension = partial(Extension, cuda=False)
         CUDAExtension = partial(Extension, cuda=True)
@@ -48,7 +50,7 @@ def _get_extension():
 def _get_pool():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
@@ -60,7 +62,7 @@ def _get_pool():
 def _get_norm():
-    if torch.__version__ == 'parrots':
+    if TORCH_VERSION == 'parrots':
         from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
         SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
     else:
@@ -81,11 +83,11 @@ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
 class SyncBatchNorm(SyncBatchNorm_):

     def _specify_ddp_gpu_num(self, gpu_size):
-        if torch.__version__ != 'parrots':
+        if TORCH_VERSION != 'parrots':
             super()._specify_ddp_gpu_num(gpu_size)

     def _check_input_dim(self, input):
-        if torch.__version__ == 'parrots':
+        if TORCH_VERSION == 'parrots':
             if input.dim() < 2:
                 raise ValueError(
                     f'expected at least 2D input (got {input.dim()}D input)')