"vscode:/vscode.git/clone" did not exist on "2d56f106a60c7b698705494e7539f8a7e4c85dd9"
Unverified Commit 5dba1c00 authored by wang shiguang's avatar wang shiguang Committed by GitHub
Browse files

Parrots wrapper (#258)



* add parrots wrapper

* Update parrots_wrapper.py

* Update parrots_wrapper.py

* fix commits

* fix

* fix lint and adapt torch.__version__

* mirror fix
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
parent 44e6b93d
......@@ -22,7 +22,13 @@ class TensorboardLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
if torch.__version__ >= '1.1' and '.' in torch.__version__:
if torch.__version__ < '1.1' or torch.__version__ == 'parrots':
try:
from tensorboardX import SummaryWriter
except ImportError:
raise ImportError('Please install tensorboardX to use '
'TensorboardLoggerHook.')
else:
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
......@@ -30,12 +36,7 @@ class TensorboardLoggerHook(LoggerHook):
'Please run "pip install future tensorboard" to install '
'the dependencies to use torch.utils.tensorboard '
'(applicable to PyTorch 1.1 or higher)')
else:
try:
from tensorboardX import SummaryWriter
except ImportError:
raise ImportError('Please install tensorboardX to use '
'TensorboardLoggerHook.')
if self.log_dir is None:
self.log_dir = osp.join(runner.work_dir, 'tf_logs')
self.writer = SummaryWriter(self.log_dir)
......
......@@ -5,6 +5,10 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
is_str, is_tuple_of, iter_cast, list_cast,
requires_executable, requires_package, slice_list,
tuple_cast)
from .parrots_wrapper import (CUDA_HOME, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm,
_ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
......@@ -19,5 +23,8 @@ __all__ = [
'requires_package', 'requires_executable', 'is_filepath', 'fopen',
'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time'
'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time',
'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config'
]
import torch
def _get_cuda_home():
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
return CUDA_HOME
def get_build_config():
    """Return the build-configuration string of the active backend.

    Dispatches to parrots' build info when running under parrots,
    otherwise to PyTorch's ``torch.__config__.show()``.
    """
    if torch.__version__ == 'parrots':
        from parrots.config import get_build_info
        return get_build_info()
    return torch.__config__.show()
def _get_conv():
if torch.__version__ == 'parrots':
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
else:
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
return _ConvNd, _ConvTransposeMixin
def _get_pool():
if torch.__version__ == 'parrots':
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
else:
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
def _get_norm():
if torch.__version__ == 'parrots':
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
else:
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.nn.modules.batchnorm import _BatchNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
# Resolve backend-specific names once at import time so downstream code
# can import them directly from this module regardless of whether the
# runtime is native PyTorch or parrots.
CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv()
# SyncBatchNorm_ is the raw backend class; the SyncBatchNorm wrapper
# class defined below subclasses it to paper over API differences.
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_):
    """Synchronized batch norm that works under both torch and parrots.

    Bridges small API differences between ``torch.nn.SyncBatchNorm`` and
    parrots' ``SyncBatchNorm2d`` so callers can use one class everywhere.
    """

    def _specify_ddp_gpu_num(self, gpu_size):
        # parrots' sync-BN has no DDP GPU-count hook; only delegate on
        # native torch.
        if torch.__version__ == 'parrots':
            return
        super()._specify_ddp_gpu_num(gpu_size)

    def _check_input_dim(self, input):
        if torch.__version__ != 'parrots':
            super()._check_input_dim(input)
            return
        # parrots backend: accept any input of rank >= 2.
        if input.dim() < 2:
            raise ValueError(
                f'expected at least 2D input (got {input.dim()}D input)')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment