Unverified Commit 45fa3e44 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
parent c561264d
...@@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module): ...@@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
If True, align the results more perfectly. If True, align the results more perfectly.
""" """
super(SimpleRoIAlign, self).__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
# to be consistent with other RoI ops # to be consistent with other RoI ops
......
...@@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply ...@@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module): class PSAMask(nn.Module):
def __init__(self, psa_type, mask_size=None): def __init__(self, psa_type, mask_size=None):
super(PSAMask, self).__init__() super().__init__()
assert psa_type in ['collect', 'distribute'] assert psa_type in ['collect', 'distribute']
if psa_type == 'collect': if psa_type == 'collect':
psa_type_enum = 0 psa_type_enum = 0
......
...@@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module): ...@@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
num_samples=0, num_samples=0,
num_orientations=8, num_orientations=8,
clockwise=False): clockwise=False):
super(RiRoIAlignRotated, self).__init__() super().__init__()
self.out_size = out_size self.out_size = out_size
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
......
...@@ -181,7 +181,7 @@ class RoIAlign(nn.Module): ...@@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
pool_mode='avg', pool_mode='avg',
aligned=True, aligned=True,
use_torchvision=False): use_torchvision=False):
super(RoIAlign, self).__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
......
...@@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module): ...@@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
sampling_ratio=0, sampling_ratio=0,
aligned=True, aligned=True,
clockwise=False): clockwise=False):
super(RoIAlignRotated, self).__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
......
...@@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply ...@@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
class RoIPool(nn.Module): class RoIPool(nn.Module):
def __init__(self, output_size, spatial_scale=1.0): def __init__(self, output_size, spatial_scale=1.0):
super(RoIPool, self).__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
......
...@@ -64,7 +64,7 @@ class SparseConvolution(SparseModule): ...@@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
inverse=False, inverse=False,
indice_key=None, indice_key=None,
fused_bn=False): fused_bn=False):
super(SparseConvolution, self).__init__() super().__init__()
assert groups == 1 assert groups == 1
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim kernel_size = [kernel_size] * ndim
...@@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution): ...@@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SparseConv2d, self).__init__( super().__init__(
2, 2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution): ...@@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SparseConv3d, self).__init__( super().__init__(
3, 3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution): ...@@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SparseConv4d, self).__init__( super().__init__(
4, 4,
in_channels, in_channels,
out_channels, out_channels,
...@@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SparseConvTranspose2d, self).__init__( super().__init__(
2, 2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SparseConvTranspose3d, self).__init__( super().__init__(
3, 3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size, kernel_size,
indice_key=None, indice_key=None,
bias=True): bias=True):
super(SparseInverseConv2d, self).__init__( super().__init__(
2, 2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size, kernel_size,
indice_key=None, indice_key=None,
bias=True): bias=True):
super(SparseInverseConv3d, self).__init__( super().__init__(
3, 3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution): ...@@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SubMConv2d, self).__init__( super().__init__(
2, 2,
in_channels, in_channels,
out_channels, out_channels,
...@@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution): ...@@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SubMConv3d, self).__init__( super().__init__(
3, 3,
in_channels, in_channels,
out_channels, out_channels,
...@@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution): ...@@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
groups=1, groups=1,
bias=True, bias=True,
indice_key=None): indice_key=None):
super(SubMConv4d, self).__init__( super().__init__(
4, 4,
in_channels, in_channels,
out_channels, out_channels,
......
...@@ -86,7 +86,7 @@ class SparseSequential(SparseModule): ...@@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SparseSequential, self).__init__() super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items(): for key, module in args[0].items():
self.add_module(key, module) self.add_module(key, module)
...@@ -103,7 +103,7 @@ class SparseSequential(SparseModule): ...@@ -103,7 +103,7 @@ class SparseSequential(SparseModule):
def __getitem__(self, idx): def __getitem__(self, idx):
if not (-len(self) <= idx < len(self)): if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx)) raise IndexError(f'index {idx} is out of range')
if idx < 0: if idx < 0:
idx += len(self) idx += len(self)
it = iter(self._modules.values()) it = iter(self._modules.values())
......
...@@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule): ...@@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
padding=0, padding=0,
dilation=1, dilation=1,
subm=False): subm=False):
super(SparseMaxPool, self).__init__() super().__init__()
if not isinstance(kernel_size, (list, tuple)): if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim kernel_size = [kernel_size] * ndim
if not isinstance(stride, (list, tuple)): if not isinstance(stride, (list, tuple)):
...@@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule): ...@@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
class SparseMaxPool2d(SparseMaxPool): class SparseMaxPool2d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1): def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding, super().__init__(2, kernel_size, stride, padding, dilation)
dilation)
class SparseMaxPool3d(SparseMaxPool): class SparseMaxPool3d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1): def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding, super().__init__(3, kernel_size, stride, padding, dilation)
dilation)
...@@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape): ...@@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
return ret return ret
class SparseConvTensor(object): class SparseConvTensor:
def __init__(self, def __init__(self,
features, features,
......
...@@ -198,7 +198,7 @@ class SyncBatchNorm(Module): ...@@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
track_running_stats=True, track_running_stats=True,
group=None, group=None,
stats_mode='default'): stats_mode='default'):
super(SyncBatchNorm, self).__init__() super().__init__()
self.num_features = num_features self.num_features = num_features
self.eps = eps self.eps = eps
self.momentum = momentum self.momentum = momentum
......
...@@ -32,7 +32,7 @@ class MMDataParallel(DataParallel): ...@@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
""" """
def __init__(self, *args, dim=0, **kwargs): def __init__(self, *args, dim=0, **kwargs):
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs) super().__init__(*args, dim=dim, **kwargs)
self.dim = dim self.dim = dim
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
......
...@@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module): ...@@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
dim=0, dim=0,
broadcast_buffers=True, broadcast_buffers=True,
bucket_cap_mb=25): bucket_cap_mb=25):
super(MMDistributedDataParallel, self).__init__() super().__init__()
self.module = module self.module = module
self.dim = dim self.dim = dim
self.broadcast_buffers = broadcast_buffers self.broadcast_buffers = broadcast_buffers
......
...@@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# NOTE init_cfg can be defined in different levels, but init_cfg # NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority. # in low levels has a higher priority.
super(BaseModule, self).__init__() super().__init__()
# define default value of init_cfg instead of hard code # define default value of init_cfg instead of hard code
# in init_weights() function # in init_weights() function
self._is_init = False self._is_init = False
......
...@@ -83,8 +83,8 @@ class CheckpointHook(Hook): ...@@ -83,8 +83,8 @@ class CheckpointHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep)) basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename) self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by ' runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.')) f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not # disable the create_symlink option because some file backends do not
# allow to create a symlink # allow to create a symlink
...@@ -93,9 +93,9 @@ class CheckpointHook(Hook): ...@@ -93,9 +93,9 @@ class CheckpointHook(Hook):
'create_symlink'] and not self.file_client.allow_symlink: 'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False self.args['create_symlink'] = False
warnings.warn( warnings.warn(
('create_symlink is set as True by the user but is changed' 'create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not ' 'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')) f'allowed in {self.file_client.name}')
else: else:
self.args['create_symlink'] = self.file_client.allow_symlink self.args['create_symlink'] = self.file_client.allow_symlink
......
...@@ -214,8 +214,8 @@ class EvalHook(Hook): ...@@ -214,8 +214,8 @@ class EvalHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep)) basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename) self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info( runner.logger.info(
(f'The best checkpoint will be saved to {self.out_dir} by ' f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}')) f'{self.file_client.name}')
if self.save_best is not None: if self.save_best is not None:
if runner.meta is None: if runner.meta is None:
...@@ -335,8 +335,8 @@ class EvalHook(Hook): ...@@ -335,8 +335,8 @@ class EvalHook(Hook):
self.best_ckpt_path): self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path) self.file_client.remove(self.best_ckpt_path)
runner.logger.info( runner.logger.info(
(f'The previous best checkpoint {self.best_ckpt_path} was ' f'The previous best checkpoint {self.best_ckpt_path} was '
'removed')) 'removed')
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = self.file_client.join_path( self.best_ckpt_path = self.file_client.join_path(
......
...@@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook): ...@@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
ignore_last=True, ignore_last=True,
reset_flag=False, reset_flag=False,
by_epoch=True): by_epoch=True):
super(ClearMLLoggerHook, self).__init__(interval, ignore_last, super().__init__(interval, ignore_last, reset_flag, by_epoch)
reset_flag, by_epoch)
self.import_clearml() self.import_clearml()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
...@@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook): ...@@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):
@master_only @master_only
def before_run(self, runner): def before_run(self, runner):
super(ClearMLLoggerHook, self).before_run(runner) super().before_run(runner)
task_kwargs = self.init_kwargs if self.init_kwargs else {} task_kwargs = self.init_kwargs if self.init_kwargs else {}
self.task = self.clearml.Task.init(**task_kwargs) self.task = self.clearml.Task.init(**task_kwargs)
self.task_logger = self.task.get_logger() self.task_logger = self.task.get_logger()
......
...@@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
ignore_last=True, ignore_last=True,
reset_flag=False, reset_flag=False,
by_epoch=True): by_epoch=True):
super(MlflowLoggerHook, self).__init__(interval, ignore_last, super().__init__(interval, ignore_last, reset_flag, by_epoch)
reset_flag, by_epoch)
self.import_mlflow() self.import_mlflow()
self.exp_name = exp_name self.exp_name = exp_name
self.tags = tags self.tags = tags
...@@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):
@master_only @master_only
def before_run(self, runner): def before_run(self, runner):
super(MlflowLoggerHook, self).before_run(runner) super().before_run(runner)
if self.exp_name is not None: if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name) self.mlflow.set_experiment(self.exp_name)
if self.tags is not None: if self.tags is not None:
......
...@@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook): ...@@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
with_step=True, with_step=True,
by_epoch=True): by_epoch=True):
super(NeptuneLoggerHook, self).__init__(interval, ignore_last, super().__init__(interval, ignore_last, reset_flag, by_epoch)
reset_flag, by_epoch)
self.import_neptune() self.import_neptune()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.with_step = with_step self.with_step = with_step
......
...@@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
reset_flag=False, reset_flag=False,
by_epoch=True, by_epoch=True,
img_key='img_info'): img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag, super().__init__(interval, ignore_last, reset_flag, by_epoch)
by_epoch)
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.add_graph = add_graph self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt self.add_last_ckpt = add_last_ckpt
...@@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook): ...@@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
@master_only @master_only
def before_run(self, runner): def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner) super().before_run(runner)
try: try:
from pavi import SummaryWriter from pavi import SummaryWriter
except ImportError: except ImportError:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment