Unverified Commit 45fa3e44 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
parent c561264d
......@@ -296,7 +296,7 @@ class SimpleRoIAlign(nn.Module):
If True, align the results more perfectly.
"""
super(SimpleRoIAlign, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
# to be consistent with other RoI ops
......
......@@ -72,7 +72,7 @@ psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module):
def __init__(self, psa_type, mask_size=None):
super(PSAMask, self).__init__()
super().__init__()
assert psa_type in ['collect', 'distribute']
if psa_type == 'collect':
psa_type_enum = 0
......
......@@ -116,7 +116,7 @@ class RiRoIAlignRotated(nn.Module):
num_samples=0,
num_orientations=8,
clockwise=False):
super(RiRoIAlignRotated, self).__init__()
super().__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)
......
......@@ -181,7 +181,7 @@ class RoIAlign(nn.Module):
pool_mode='avg',
aligned=True,
use_torchvision=False):
super(RoIAlign, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
......
......@@ -156,7 +156,7 @@ class RoIAlignRotated(nn.Module):
sampling_ratio=0,
aligned=True,
clockwise=False):
super(RoIAlignRotated, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
......
......@@ -71,7 +71,7 @@ roi_pool = RoIPoolFunction.apply
class RoIPool(nn.Module):
def __init__(self, output_size, spatial_scale=1.0):
super(RoIPool, self).__init__()
super().__init__()
self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
......
......@@ -64,7 +64,7 @@ class SparseConvolution(SparseModule):
inverse=False,
indice_key=None,
fused_bn=False):
super(SparseConvolution, self).__init__()
super().__init__()
assert groups == 1
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim
......@@ -217,7 +217,7 @@ class SparseConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
......@@ -243,7 +243,7 @@ class SparseConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
......@@ -269,7 +269,7 @@ class SparseConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConv4d, self).__init__(
super().__init__(
4,
in_channels,
out_channels,
......@@ -295,7 +295,7 @@ class SparseConvTranspose2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConvTranspose2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
......@@ -322,7 +322,7 @@ class SparseConvTranspose3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SparseConvTranspose3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
......@@ -345,7 +345,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size,
indice_key=None,
bias=True):
super(SparseInverseConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
......@@ -364,7 +364,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size,
indice_key=None,
bias=True):
super(SparseInverseConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
......@@ -387,7 +387,7 @@ class SubMConv2d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv2d, self).__init__(
super().__init__(
2,
in_channels,
out_channels,
......@@ -414,7 +414,7 @@ class SubMConv3d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv3d, self).__init__(
super().__init__(
3,
in_channels,
out_channels,
......@@ -441,7 +441,7 @@ class SubMConv4d(SparseConvolution):
groups=1,
bias=True,
indice_key=None):
super(SubMConv4d, self).__init__(
super().__init__(
4,
in_channels,
out_channels,
......
......@@ -86,7 +86,7 @@ class SparseSequential(SparseModule):
"""
def __init__(self, *args, **kwargs):
super(SparseSequential, self).__init__()
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
......@@ -103,7 +103,7 @@ class SparseSequential(SparseModule):
def __getitem__(self, idx):
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
it = iter(self._modules.values())
......
......@@ -29,7 +29,7 @@ class SparseMaxPool(SparseModule):
padding=0,
dilation=1,
subm=False):
super(SparseMaxPool, self).__init__()
super().__init__()
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * ndim
if not isinstance(stride, (list, tuple)):
......@@ -77,12 +77,10 @@ class SparseMaxPool(SparseModule):
class SparseMaxPool2d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
dilation)
super().__init__(2, kernel_size, stride, padding, dilation)
class SparseMaxPool3d(SparseMaxPool):
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
dilation)
super().__init__(3, kernel_size, stride, padding, dilation)
......@@ -18,7 +18,7 @@ def scatter_nd(indices, updates, shape):
return ret
class SparseConvTensor(object):
class SparseConvTensor:
def __init__(self,
features,
......
......@@ -198,7 +198,7 @@ class SyncBatchNorm(Module):
track_running_stats=True,
group=None,
stats_mode='default'):
super(SyncBatchNorm, self).__init__()
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
......
......@@ -32,7 +32,7 @@ class MMDataParallel(DataParallel):
"""
def __init__(self, *args, dim=0, **kwargs):
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
def forward(self, *inputs, **kwargs):
......
......@@ -18,7 +18,7 @@ class MMDistributedDataParallel(nn.Module):
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25):
super(MMDistributedDataParallel, self).__init__()
super().__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers
......
......@@ -35,7 +35,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
super().__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False
......
......@@ -83,8 +83,8 @@ class CheckpointHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not
# allow to create a symlink
......@@ -93,9 +93,9 @@ class CheckpointHook(Hook):
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
('create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}'))
'create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')
else:
self.args['create_symlink'] = self.file_client.allow_symlink
......
......@@ -214,8 +214,8 @@ class EvalHook(Hook):
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}'))
f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}')
if self.save_best is not None:
if runner.meta is None:
......@@ -335,8 +335,8 @@ class EvalHook(Hook):
self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path)
runner.logger.info(
(f'The previous best checkpoint {self.best_ckpt_path} was '
'removed'))
f'The previous best checkpoint {self.best_ckpt_path} was '
'removed')
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = self.file_client.join_path(
......
......@@ -34,8 +34,7 @@ class ClearMLLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_clearml()
self.init_kwargs = init_kwargs
......@@ -49,7 +48,7 @@ class ClearMLLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(ClearMLLoggerHook, self).before_run(runner)
super().before_run(runner)
task_kwargs = self.init_kwargs if self.init_kwargs else {}
self.task = self.clearml.Task.init(**task_kwargs)
self.task_logger = self.task.get_logger()
......
......@@ -40,8 +40,7 @@ class MlflowLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
by_epoch=True):
super(MlflowLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_mlflow()
self.exp_name = exp_name
self.tags = tags
......@@ -59,7 +58,7 @@ class MlflowLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(MlflowLoggerHook, self).before_run(runner)
super().before_run(runner)
if self.exp_name is not None:
self.mlflow.set_experiment(self.exp_name)
if self.tags is not None:
......
......@@ -49,8 +49,7 @@ class NeptuneLoggerHook(LoggerHook):
with_step=True,
by_epoch=True):
super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_neptune()
self.init_kwargs = init_kwargs
self.with_step = with_step
......
......@@ -40,8 +40,7 @@ class PaviLoggerHook(LoggerHook):
reset_flag=False,
by_epoch=True,
img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt
......@@ -49,7 +48,7 @@ class PaviLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner)
super().before_run(runner)
try:
from pavi import SummaryWriter
except ImportError:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment