Unverified commit 45fa3e44, authored by Zaida Zhou, committed by GitHub

Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
parent c561264d
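
This commit applies the rewrites produced by the pyupgrade pre-commit hook across the codebase; the hook registration itself lives in .pre-commit-config.yaml and is not visible in the hunks below. The bulk of the diff is the zero-argument super() rewrite. A minimal sketch of why the two spellings are equivalent in Python 3 (class names are illustrative, not from mmcv):

    class Base:
        def __init__(self, interval):
            self.interval = interval

    class Child(Base):
        def __init__(self, interval):
            # In Python 3, super() inside a method picks up the enclosing
            # class and the first argument automatically, so this is
            # equivalent to super(Child, self).__init__(interval).
            super().__init__(interval)

    assert Child(10).interval == 10

The zero-argument form is also robust to renaming the class, since the class name is no longer repeated inside its own methods.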
@@ -27,8 +27,7 @@ class SegmindLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(SegmindLoggerHook, self).__init__(interval, ignore_last,
-                                                reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_segmind()
 
     def import_segmind(self):
...
@@ -28,13 +28,12 @@ class TensorboardLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  by_epoch=True):
-        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
-                                                    reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.log_dir = log_dir
 
     @master_only
     def before_run(self, runner):
-        super(TensorboardLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if (TORCH_VERSION == 'parrots'
                 or digit_version(TORCH_VERSION) < digit_version('1.1')):
             try:
...
@@ -62,8 +62,7 @@ class TextLoggerHook(LoggerHook):
                  out_suffix=('.log.json', '.log', '.py'),
                  keep_local=True,
                  file_client_args=None):
-        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
-                                             by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
@@ -87,7 +86,7 @@ class TextLoggerHook(LoggerHook):
                                  self.out_dir)
 
     def before_run(self, runner):
-        super(TextLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if self.out_dir is not None:
             self.file_client = FileClient.infer_client(self.file_client_args,
@@ -97,8 +96,8 @@ class TextLoggerHook(LoggerHook):
             basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(self.out_dir, basename)
             runner.logger.info(
-                (f'Text logs will be saved to {self.out_dir} by '
-                 f'{self.file_client.name} after the training process.'))
+                f'Text logs will be saved to {self.out_dir} by '
+                f'{self.file_client.name} after the training process.')
 
         self.start_iter = runner.iter
         self.json_log_path = osp.join(runner.work_dir,
@@ -242,15 +241,15 @@ class TextLoggerHook(LoggerHook):
                 local_filepath = osp.join(runner.work_dir, filename)
                 out_filepath = self.file_client.join_path(
                     self.out_dir, filename)
-                with open(local_filepath, 'r') as f:
+                with open(local_filepath) as f:
                     self.file_client.put_text(f.read(), out_filepath)
 
                 runner.logger.info(
-                    (f'The file {local_filepath} has been uploaded to '
-                     f'{out_filepath}.'))
+                    f'The file {local_filepath} has been uploaded to '
+                    f'{out_filepath}.')
 
                 if not self.keep_local:
                     os.remove(local_filepath)
                     runner.logger.info(
-                        (f'{local_filepath} was removed due to the '
-                         '`self.keep_local=False`'))
+                        f'{local_filepath} was removed due to the '
+                        '`self.keep_local=False`')
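
Two further rewrites appear in the TextLoggerHook hunks above: open()'s mode argument defaults to 'r', so passing it explicitly is redundant, and the parentheses wrapping a multi-line string argument are unnecessary because adjacent string literals concatenate at parse time. A minimal sketch (the file name is illustrative):

    # mode='r' is the default, so open(path) reads text just like open(path, 'r')
    with open('demo.txt', 'w', encoding='utf-8') as f:
        f.write('hello')

    with open('demo.txt', encoding='utf-8') as f:
        assert f.read() == 'hello'

    # adjacent string literals concatenate, with or without extra parentheses
    msg = ('a' 'b')
    assert msg == 'ab'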
@@ -57,8 +57,7 @@ class WandbLoggerHook(LoggerHook):
                  with_step=True,
                  log_artifact=True,
                  out_suffix=('.log.json', '.log', '.py')):
-        super(WandbLoggerHook, self).__init__(interval, ignore_last,
-                                              reset_flag, by_epoch)
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
         self.import_wandb()
         self.init_kwargs = init_kwargs
         self.commit = commit
@@ -76,7 +75,7 @@ class WandbLoggerHook(LoggerHook):
     @master_only
     def before_run(self, runner):
-        super(WandbLoggerHook, self).before_run(runner)
+        super().before_run(runner)
         if self.wandb is None:
             self.import_wandb()
         if self.init_kwargs:
...
@@ -157,7 +157,7 @@ class LrUpdaterHook(Hook):
 class FixedLrUpdaterHook(LrUpdaterHook):
 
     def __init__(self, **kwargs):
-        super(FixedLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         return base_lr
@@ -188,7 +188,7 @@ class StepLrUpdaterHook(LrUpdaterHook):
         self.step = step
         self.gamma = gamma
         self.min_lr = min_lr
-        super(StepLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter
@@ -215,7 +215,7 @@ class ExpLrUpdaterHook(LrUpdaterHook):
 
     def __init__(self, gamma, **kwargs):
         self.gamma = gamma
-        super(ExpLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter
@@ -228,7 +228,7 @@ class PolyLrUpdaterHook(LrUpdaterHook):
 
     def __init__(self, power=1., min_lr=0., **kwargs):
         self.power = power
         self.min_lr = min_lr
-        super(PolyLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         if self.by_epoch:
@@ -247,7 +247,7 @@ class InvLrUpdaterHook(LrUpdaterHook):
 
     def __init__(self, gamma, power=1., **kwargs):
         self.gamma = gamma
         self.power = power
-        super(InvLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         progress = runner.epoch if self.by_epoch else runner.iter
@@ -269,7 +269,7 @@ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
         assert (min_lr is None) ^ (min_lr_ratio is None)
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         if self.by_epoch:
@@ -317,7 +317,7 @@ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
         self.start_percent = start_percent
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         if self.by_epoch:
@@ -367,7 +367,7 @@ class CosineRestartLrUpdaterHook(LrUpdaterHook):
         self.restart_weights = restart_weights
         assert (len(self.periods) == len(self.restart_weights)
                 ), 'periods and restart_weights should have the same length.'
-        super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
         self.cumulative_periods = [
             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
@@ -484,10 +484,10 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
-        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+        super().__init__(by_epoch, **kwargs)
 
     def before_run(self, runner):
-        super(CyclicLrUpdaterHook, self).before_run(runner)
+        super().before_run(runner)
         # initiate lr_phases
         # total lr_phases are separated as up and down
         self.max_iter_per_phase = runner.max_iters // self.cyclic_times
@@ -598,7 +598,7 @@ class OneCycleLrUpdaterHook(LrUpdaterHook):
         self.final_div_factor = final_div_factor
         self.three_phase = three_phase
         self.lr_phases = []  # init lr_phases
-        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def before_run(self, runner):
         if hasattr(self, 'total_steps'):
@@ -668,7 +668,7 @@ class LinearAnnealingLrUpdaterHook(LrUpdaterHook):
         assert (min_lr is None) ^ (min_lr_ratio is None)
         self.min_lr = min_lr
         self.min_lr_ratio = min_lr_ratio
-        super(LinearAnnealingLrUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_lr(self, runner, base_lr):
         if self.by_epoch:
...
@@ -176,7 +176,7 @@ class StepMomentumUpdaterHook(MomentumUpdaterHook):
         self.step = step
         self.gamma = gamma
         self.min_momentum = min_momentum
-        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_momentum(self, runner, base_momentum):
         progress = runner.epoch if self.by_epoch else runner.iter
@@ -214,7 +214,7 @@ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
-        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_momentum(self, runner, base_momentum):
         if self.by_epoch:
@@ -247,7 +247,7 @@ class LinearAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
         assert (min_momentum is None) ^ (min_momentum_ratio is None)
         self.min_momentum = min_momentum
         self.min_momentum_ratio = min_momentum_ratio
-        super(LinearAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def get_momentum(self, runner, base_momentum):
         if self.by_epoch:
@@ -328,10 +328,10 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
         # currently only support by_epoch=False
         assert not by_epoch, \
             'currently only support "by_epoch" = False'
-        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+        super().__init__(by_epoch, **kwargs)
 
     def before_run(self, runner):
-        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        super().before_run(runner)
         # initiate momentum_phases
         # total momentum_phases are separated as up and down
         max_iter_per_phase = runner.max_iters // self.cyclic_times
@@ -439,7 +439,7 @@ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
             self.anneal_func = annealing_linear
         self.three_phase = three_phase
         self.momentum_phases = []  # init momentum_phases
-        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def before_run(self, runner):
         if isinstance(runner.optimizer, dict):
...
@@ -110,7 +110,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
     """
 
     def __init__(self, cumulative_iters=1, **kwargs):
-        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
         assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
             f'cumulative_iters only accepts positive int, but got ' \
@@ -297,8 +297,7 @@ if (TORCH_VERSION != 'parrots'
         """
 
         def __init__(self, *args, **kwargs):
-            super(GradientCumulativeFp16OptimizerHook,
-                  self).__init__(*args, **kwargs)
+            super().__init__(*args, **kwargs)
 
         def after_train_iter(self, runner):
            if not self.initialized:
@@ -490,8 +489,7 @@ else:
         iters gradient cumulating."""
 
        def __init__(self, *args, **kwargs):
-            super(GradientCumulativeFp16OptimizerHook,
-                  self).__init__(*args, **kwargs)
+            super().__init__(*args, **kwargs)
 
        def after_train_iter(self, runner):
            if not self.initialized:
...
@@ -263,7 +263,7 @@ class IterBasedRunner(BaseRunner):
         if log_config is not None:
             for info in log_config['hooks']:
                 info.setdefault('by_epoch', False)
-        super(IterBasedRunner, self).register_training_hooks(
+        super().register_training_hooks(
             lr_config=lr_config,
             momentum_config=momentum_config,
             optimizer_config=optimizer_config,
...
@@ -54,7 +54,7 @@ def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
         msg += reset_style
         warnings.warn(msg)
 
-    device = torch.device('cuda:{}'.format(device_id))
+    device = torch.device(f'cuda:{device_id}')
 
     # create builder and network
     logger = trt.Logger(log_level)
     builder = trt.Builder(logger)
@@ -209,7 +209,7 @@ class TRTWrapper(torch.nn.Module):
             msg += reset_style
             warnings.warn(msg)
 
-        super(TRTWrapper, self).__init__()
+        super().__init__()
         self.engine = engine
         if isinstance(self.engine, str):
             self.engine = load_trt_engine(engine)
...
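
The onnx2trt hunk swaps str.format for an f-string; both build the same string, the f-string just inlines the expression. A quick sketch (the value is illustrative):

    device_id = 0
    # the two spellings produce identical strings
    assert 'cuda:{}'.format(device_id) == f'cuda:{device_id}' == 'cuda:0'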
@@ -39,7 +39,7 @@ class ConfigDict(Dict):
 
     def __getattr__(self, name):
         try:
-            value = super(ConfigDict, self).__getattr__(name)
+            value = super().__getattr__(name)
         except KeyError:
             ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'")
@@ -96,7 +96,7 @@ class Config:
 
     @staticmethod
     def _validate_py_syntax(filename):
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             content = f.read()
         try:
@@ -116,7 +116,7 @@ class Config:
             fileBasename=file_basename,
             fileBasenameNoExtension=file_basename_no_extension,
             fileExtname=file_extname)
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             config_file = f.read()
         for key, value in support_templates.items():
@@ -130,7 +130,7 @@ class Config:
     def _pre_substitute_base_vars(filename, temp_config_name):
         """Substitute base variable placehoders to string, so that parsing
         would work."""
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             config_file = f.read()
         base_var_dict = {}
@@ -183,7 +183,7 @@ class Config:
         check_file_exist(filename)
         fileExtname = osp.splitext(filename)[1]
         if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
-            raise IOError('Only py/yml/yaml/json type are supported now!')
+            raise OSError('Only py/yml/yaml/json type are supported now!')
         with tempfile.TemporaryDirectory() as temp_config_dir:
             temp_config_file = tempfile.NamedTemporaryFile(
@@ -236,7 +236,7 @@ class Config:
             warnings.warn(warning_msg, DeprecationWarning)
 
         cfg_text = filename + '\n'
-        with open(filename, 'r', encoding='utf-8') as f:
+        with open(filename, encoding='utf-8') as f:
             # Setting encoding explicitly to resolve coding issue on windows
             cfg_text += f.read()
@@ -356,7 +356,7 @@ class Config:
             :obj:`Config`: Config obj.
         """
         if file_format not in ['.py', '.json', '.yaml', '.yml']:
-            raise IOError('Only py/yml/yaml/json type are supported now!')
+            raise OSError('Only py/yml/yaml/json type are supported now!')
         if file_format != '.py' and 'dict(' in cfg_str:
             # check if users specify a wrong suffix for python
             warnings.warn(
@@ -396,16 +396,16 @@ class Config:
         if isinstance(filename, Path):
             filename = str(filename)
-        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
-        super(Config, self).__setattr__('_filename', filename)
+        super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+        super().__setattr__('_filename', filename)
         if cfg_text:
             text = cfg_text
         elif filename:
-            with open(filename, 'r') as f:
+            with open(filename) as f:
                 text = f.read()
         else:
             text = ''
-        super(Config, self).__setattr__('_text', text)
+        super().__setattr__('_text', text)
 
     @property
     def filename(self):
@@ -556,9 +556,9 @@ class Config:
     def __setstate__(self, state):
         _cfg_dict, _filename, _text = state
-        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
-        super(Config, self).__setattr__('_filename', _filename)
-        super(Config, self).__setattr__('_text', _text)
+        super().__setattr__('_cfg_dict', _cfg_dict)
+        super().__setattr__('_filename', _filename)
+        super().__setattr__('_text', _text)
 
     def dump(self, file=None):
         """Dumps config into a file or returns a string representation of the
@@ -584,7 +584,7 @@ class Config:
                 will be dumped. Defaults to None.
         """
         import mmcv
-        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+        cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
         if file is None:
             if self.filename is None or self.filename.endswith('.py'):
                 return self.pretty_text
@@ -638,8 +638,8 @@ class Config:
                 subkey = key_list[-1]
                 d[subkey] = v
-        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
-        super(Config, self).__setattr__(
+        cfg_dict = super().__getattribute__('_cfg_dict')
+        super().__setattr__(
             '_cfg_dict',
             Config._merge_a_into_b(
                 option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
...
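
The IOError to OSError change in config.py is purely cosmetic: since Python 3.3, IOError is an alias of OSError, so raising or catching either behaves identically. A quick check (the path is illustrative):

    assert IOError is OSError

    try:
        open('/nonexistent/path')  # raises FileNotFoundError, a subclass of OSError
    except OSError:
        pass  # an `except IOError` handler would catch it just the same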
@@ -6,7 +6,7 @@ class TimerError(Exception):
 
     def __init__(self, message):
         self.message = message
-        super(TimerError, self).__init__(message)
+        super().__init__(message)
 
 
 class Timer:
...
@@ -40,10 +40,10 @@ def flowread(flow_or_path: Union[np.ndarray, str],
             try:
                 header = f.read(4).decode('utf-8')
             except Exception:
-                raise IOError(f'Invalid flow file: {flow_or_path}')
+                raise OSError(f'Invalid flow file: {flow_or_path}')
             else:
                 if header != 'PIEH':
-                    raise IOError(f'Invalid flow file: {flow_or_path}, '
-                                  'header does not contain PIEH')
+                    raise OSError(f'Invalid flow file: {flow_or_path}, '
+                                  'header does not contain PIEH')
 
             w = np.fromfile(f, np.int32, 1).squeeze()
@@ -53,7 +53,7 @@ def flowread(flow_or_path: Union[np.ndarray, str],
         assert concat_axis in [0, 1]
         cat_flow = imread(flow_or_path, flag='unchanged')
         if cat_flow.ndim != 2:
-            raise IOError(
+            raise OSError(
                 f'{flow_or_path} is not a valid quantized flow file, '
                 f'its dimension is {cat_flow.ndim}.')
         assert cat_flow.shape[concat_axis] % 2 == 0
@@ -86,7 +86,7 @@ def flowwrite(flow: np.ndarray,
     """
     if not quantize:
         with open(filename, 'wb') as f:
-            f.write('PIEH'.encode('utf-8'))
+            f.write(b'PIEH')
             np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
             flow = flow.astype(np.float32)
             flow.tofile(f)
@@ -146,7 +146,7 @@ def dequantize_flow(dx: np.ndarray,
     assert dx.shape == dy.shape
     assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
 
-    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+    dx, dy = (dequantize(d, -max_val, max_val, 255) for d in [dx, dy])
 
     if denorm:
         dx *= dx.shape[1]
...
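
Two more rewrites show up in the optical-flow hunks: 'PIEH'.encode('utf-8') constructs at runtime the same bytes object that the literal b'PIEH' spells directly, and the list comprehension feeding the tuple unpacking becomes a generator expression, which skips building a temporary list. A minimal sketch (the helper stands in for dequantize):

    assert 'PIEH'.encode('utf-8') == b'PIEH'

    def scale(d):  # illustrative stand-in for dequantize(...)
        return d * 2

    # unpacking consumes the generator immediately; no intermediate list
    dx, dy = (scale(d) for d in [1, 2])
    assert (dx, dy) == (2, 4)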
 # Copyright (c) OpenMMLab. All rights reserved.
-from __future__ import division
 from typing import Optional, Union
 
 import numpy as np
...
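
pyupgrade also drops `from __future__ import division`: in Python 3 the / operator always performs true division, so the import is a no-op. For example:

    assert 3 / 2 == 1.5   # true division, the only behaviour in Python 3
    assert 3 // 2 == 1    # floor division is spelled //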
@@ -39,7 +39,7 @@ def choose_requirement(primary, secondary):
 
 def get_version():
     version_file = 'mmcv/version.py'
-    with open(version_file, 'r', encoding='utf-8') as f:
+    with open(version_file, encoding='utf-8') as f:
         exec(compile(f.read(), version_file, 'exec'))
     return locals()['__version__']
@@ -94,12 +94,11 @@ def parse_requirements(fname='requirements/runtime.txt', with_version=True):
                 yield info
 
     def parse_require_file(fpath):
-        with open(fpath, 'r') as f:
+        with open(fpath) as f:
             for line in f.readlines():
                 line = line.strip()
                 if line and not line.startswith('#'):
-                    for info in parse_line(line):
-                        yield info
+                    yield from parse_line(line)
 
     def gen_packages_items():
         if exists(require_fpath):
...
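
The setup.py hunk replaces an explicit re-yield loop with `yield from`, which delegates to a sub-generator in one statement. A hedged sketch, with parse_line reduced to a stand-in for the real requirement parser:

    def parse_line(line):  # illustrative stand-in for setup.py's parser
        yield line

    def parse_require_file(lines):
        for line in lines:
            line = line.strip()
            if line and not line.startswith('#'):
                # same as: for info in parse_line(line): yield info
                yield from parse_line(line)

    assert list(parse_require_file(['numpy', '# comment', '', 'pyyaml'])) == [
        'numpy', 'pyyaml'
    ]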
 # Copyright (c) OpenMMLab. All rights reserved.
-from __future__ import division
 import numpy as np
 import pytest
...
@@ -23,7 +23,7 @@ class ExampleConv(nn.Module):
                  groups=1,
                  bias=True,
                  norm_cfg=None):
-        super(ExampleConv, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
...
@@ -202,21 +202,22 @@ class TestFileClient:
         # test `list_dir_or_file`
         with build_temporary_directory() as tmp_dir:
             # 1. list directories and files
-            assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
-                ['dir1', 'dir2', 'text1.txt', 'text2.txt'])
+            assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
+                'dir1', 'dir2', 'text1.txt', 'text2.txt'
+            }
             # 2. list directories and files recursively
             assert set(disk_backend.list_dir_or_file(
-                tmp_dir, recursive=True)) == set([
+                tmp_dir, recursive=True)) == {
                     'dir1',
                     osp.join('dir1', 'text3.txt'), 'dir2',
                     osp.join('dir2', 'dir3'),
                     osp.join('dir2', 'dir3', 'text4.txt'),
                     osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                ])
+                }
             # 3. only list directories
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
+                    tmp_dir, list_file=False)) == {'dir1', 'dir2'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` should be None when `list_dir` is True'):
@@ -227,30 +228,30 @@ class TestFileClient:
             # 4. only list directories recursively
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_file=False, recursive=True)) == set(
-                        ['dir1', 'dir2',
-                         osp.join('dir2', 'dir3')])
+                    tmp_dir, list_file=False, recursive=True)) == {
+                        'dir1', 'dir2',
+                        osp.join('dir2', 'dir3')
+                    }
             # 5. only list files
             assert set(disk_backend.list_dir_or_file(
-                tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
+                tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
             # 6. only list files recursively
             assert set(
                 disk_backend.list_dir_or_file(
-                    tmp_dir, list_dir=False, recursive=True)) == set([
+                    tmp_dir, list_dir=False, recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'),
                         osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix='.txt')) == set(['text1.txt', 'text2.txt'])
+                    suffix='.txt')) == {'text1.txt', 'text2.txt'}
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix=('.txt',
-                            '.jpg'))) == set(['text1.txt', 'text2.txt'])
+                    suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` must be a string or tuple of strings'):
@@ -260,22 +261,22 @@ class TestFileClient:
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir, list_dir=False, suffix='.txt',
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
                         'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 disk_backend.list_dir_or_file(
                     tmp_dir,
                     list_dir=False,
                     suffix=('.txt', '.jpg'),
-                    recursive=True)) == set([
+                    recursive=True)) == {
                         osp.join('dir1', 'text3.txt'),
                         osp.join('dir2', 'dir3', 'text4.txt'),
                         osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                    ])
+                    }
 
     @patch('ceph.S3Client', MockS3Client)
     def test_ceph_backend(self):
@@ -463,21 +464,21 @@ class TestFileClient:
         with build_temporary_directory() as tmp_dir:
             # 1. list directories and files
-            assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
-                ['dir1', 'dir2', 'text1.txt', 'text2.txt'])
+            assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
+                'dir1', 'dir2', 'text1.txt', 'text2.txt'
+            }
             # 2. list directories and files recursively
             assert set(
-                petrel_backend.list_dir_or_file(
-                    tmp_dir, recursive=True)) == set([
-                        'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
-                        '/'.join(('dir2', 'dir3')), '/'.join(
-                            ('dir2', 'dir3', 'text4.txt')), '/'.join(
-                                ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
+                    'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
+                        ('dir2', 'dir3')), '/'.join(
+                            ('dir2', 'dir3', 'text4.txt')), '/'.join(
+                                ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
+                }
             # 3. only list directories
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
+                    tmp_dir, list_file=False)) == {'dir1', 'dir2'}
             with pytest.raises(
                     TypeError,
                     match=('`list_dir` should be False when `suffix` is not '
@@ -489,31 +490,30 @@ class TestFileClient:
             # 4. only list directories recursively
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_file=False, recursive=True)) == set(
-                        ['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
+                    tmp_dir, list_file=False, recursive=True)) == {
+                        'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
+                    }
             # 5. only list files
             assert set(
-                petrel_backend.list_dir_or_file(tmp_dir,
-                                                list_dir=False)) == set(
-                                                    ['text1.txt', 'text2.txt'])
+                petrel_backend.list_dir_or_file(
+                    tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
             # 6. only list files recursively
             assert set(
                 petrel_backend.list_dir_or_file(
-                    tmp_dir, list_dir=False, recursive=True)) == set([
+                    tmp_dir, list_dir=False, recursive=True)) == {
                         '/'.join(('dir1', 'text3.txt')), '/'.join(
                             ('dir2', 'dir3', 'text4.txt')), '/'.join(
                                 ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix='.txt')) == set(['text1.txt', 'text2.txt'])
+                    suffix='.txt')) == {'text1.txt', 'text2.txt'}
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False,
-                    suffix=('.txt',
-                            '.jpg'))) == set(['text1.txt', 'text2.txt'])
+                    suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
             with pytest.raises(
                     TypeError,
                     match='`suffix` must be a string or tuple of strings'):
@@ -523,22 +523,22 @@ class TestFileClient:
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir, list_dir=False, suffix='.txt',
-                    recursive=True)) == set([
+                    recursive=True)) == {
                        '/'.join(('dir1', 'text3.txt')), '/'.join(
                            ('dir2', 'dir3', 'text4.txt')), 'text1.txt',
                        'text2.txt'
-                    ])
+                    }
             # 7. only list files ending with suffix
             assert set(
                 petrel_backend.list_dir_or_file(
                     tmp_dir,
                     list_dir=False,
                     suffix=('.txt', '.jpg'),
-                    recursive=True)) == set([
+                    recursive=True)) == {
                        '/'.join(('dir1', 'text3.txt')), '/'.join(
                            ('dir2', 'dir3', 'text4.txt')), '/'.join(
                                ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
-                    ])
+                    }
 
     @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
     @patch('mc.pyvector', MagicMock)
...
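
The test rewrites above all follow one pattern: set([...]) first builds a list and then converts it, while a set literal {...} constructs the set directly. The two compare equal, so the assertions keep their meaning:

    assert set(['dir1', 'dir2']) == {'dir1', 'dir2'}
    assert {'a', 'b'} == {'b', 'a'}  # sets are unordered either way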
@@ -128,7 +128,7 @@ def test_register_handler():
     assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
 
     tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
     mmcv.dump(content, tmp_filename)
-    with open(tmp_filename, 'r') as f:
+    with open(tmp_filename) as f:
         written = f.read()
     os.remove(tmp_filename)
     assert written == '\n' + content
...
@@ -6,7 +6,7 @@ import torch
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
 
 
-class TestBBox(object):
+class TestBBox:
 
     def _test_bbox_overlaps(self, device='cpu', dtype=torch.float):
         from mmcv.ops import bbox_overlaps
...
@@ -4,7 +4,7 @@ import torch
 import torch.nn.functional as F
 
 
-class TestBilinearGridSample(object):
+class TestBilinearGridSample:
 
     def _test_bilinear_grid_sample(self,
                                    dtype=torch.float,
...
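
Finally, `class TestBBox(object)` becomes `class TestBBox:`; in Python 3 every class is new-style and inherits from object implicitly, so the explicit base adds nothing. A quick sketch (names illustrative):

    class WithBase(object):
        pass

    class WithoutBase:
        pass

    # both resolve to the same method resolution order
    assert WithoutBase.__mro__ == (WithoutBase, object)
    assert issubclass(WithoutBase, object)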