Unverified Commit 45fa3e44 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Add pyupgrade pre-commit hook (#1937)

* add pyupgrade

* add options for pyupgrade

* minor refinement
parent c561264d
......@@ -105,7 +105,7 @@ def test_initilization_info_logger():
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super(CheckLoggerModel, self).__init__(init_cfg)
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConv(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
......@@ -151,7 +151,7 @@ def test_initilization_info_logger():
class OverloadInitConvFc(nn.Conv2d, BaseModule):
def __init__(self, *args, **kwargs):
super(OverloadInitConvFc, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.conv1 = nn.Linear(1, 1)
def init_weights(self):
......@@ -162,7 +162,7 @@ def test_initilization_info_logger():
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super(CheckLoggerModel, self).__init__(init_cfg)
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
......@@ -171,7 +171,7 @@ def test_initilization_info_logger():
class TopLevelModule(BaseModule):
def __init__(self, init_cfg=None, checklog_init_cfg=None):
super(TopLevelModule, self).__init__(init_cfg)
super().__init__(init_cfg)
self.module1 = CheckLoggerModel(checklog_init_cfg)
self.module2 = OverloadInitConvFc(1, 1, 1, 1)
......
......@@ -22,7 +22,7 @@ sys.modules['petrel_client.client'] = MagicMock()
@MODULE_WRAPPERS.register_module()
class DDPWrapper(object):
class DDPWrapper:
def __init__(self, module):
self.module = module
......@@ -44,7 +44,7 @@ class Model(nn.Module):
self.conv = nn.Conv2d(3, 3, 1)
class Mockpavimodel(object):
class Mockpavimodel:
def __init__(self, name='fakename'):
self.name = name
......@@ -59,18 +59,18 @@ def assert_tensor_equal(tensor_a, tensor_b):
def test_get_state_dict():
if torch.__version__ == 'parrots':
state_dict_keys = set([
state_dict_keys = {
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'conv.weight', 'conv.bias'
])
}
else:
state_dict_keys = set([
state_dict_keys = {
'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
'block.norm.bias', 'block.norm.running_mean',
'block.norm.running_var', 'block.norm.num_batches_tracked',
'conv.weight', 'conv.bias'
])
}
model = Model()
state_dict = get_state_dict(model)
......
......@@ -70,7 +70,7 @@ def test_auto_fp16():
with pytest.raises(TypeError):
# ExampleObject is not a subclass of nn.Module
class ExampleObject(object):
class ExampleObject:
@auto_fp16()
def __call__(self, x):
......@@ -196,7 +196,7 @@ def test_force_fp32():
with pytest.raises(TypeError):
# ExampleObject is not a subclass of nn.Module
class ExampleObject(object):
class ExampleObject:
@force_fp32()
def __call__(self, x):
......
......@@ -35,12 +35,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=item)
assert isinstance(cfg, Config)
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == open(item, 'r').read()
assert cfg.text == open(item).read()
assert cfg.dump() == cfg.pretty_text
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'a.py')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
# test b.json
......@@ -48,12 +48,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file).read()
assert cfg.dump() == json.dumps(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'b.json')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
# test c.yaml
......@@ -61,12 +61,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file).read()
assert cfg.dump() == yaml.dump(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'c.yaml')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
# test h.py
......@@ -82,12 +82,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file).read()
assert cfg.dump() == cfg.pretty_text
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'h.py')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1']
assert Config.fromfile(dump_file)['item2'] == cfg_dict['item2']
......@@ -109,12 +109,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file).read()
assert cfg.dump() == yaml.dump(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'p.yaml')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1']
......@@ -128,12 +128,12 @@ def test_construct():
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file).read()
assert cfg.dump() == json.dumps(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'o.json')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert cfg.dump() == open(dump_file).read()
assert Config.fromfile(dump_file)
assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1']
......@@ -152,7 +152,7 @@ def test_fromfile():
assert isinstance(cfg, Config)
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == osp.abspath(osp.expanduser(item)) + '\n' + \
open(item, 'r').read()
open(item).read()
# test custom_imports for Config.fromfile
cfg_file = osp.join(data_path, 'config', 'q.py')
......@@ -182,7 +182,7 @@ def test_fromstring():
out_cfg = Config.fromstring(in_cfg.pretty_text, '.py')
assert in_cfg._cfg_dict == out_cfg._cfg_dict
cfg_str = open(cfg_file, 'r').read()
cfg_str = open(cfg_file).read()
out_cfg = Config.fromstring(cfg_str, file_format)
assert in_cfg._cfg_dict == out_cfg._cfg_dict
......@@ -193,7 +193,7 @@ def test_fromstring():
Config.fromstring(in_cfg.pretty_text, '.json')
# test file format error
cfg_str = open(cfg_file, 'r').read()
cfg_str = open(cfg_file).read()
with pytest.raises(Exception):
Config.fromstring(cfg_str, '.py')
......@@ -205,9 +205,9 @@ def test_merge_from_base():
assert cfg.filename == cfg_file
base_cfg_file = osp.join(data_path, 'config/base.py')
merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \
open(base_cfg_file, 'r').read()
open(base_cfg_file).read()
merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
open(cfg_file, 'r').read()
open(cfg_file).read()
assert cfg.text == merge_text
assert cfg.item1 == [2, 3]
assert cfg.item2.a == 1
......
......@@ -101,7 +101,7 @@ def test_print_log_logger(caplog):
logger = get_logger('abc', log_file=f.name)
print_log('welcome', logger=logger)
assert caplog.record_tuples[-1] == ('abc', logging.INFO, 'welcome')
with open(f.name, 'r') as fin:
with open(f.name) as fin:
log_text = fin.read()
regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}'
match = re.fullmatch(regex_time + r' - abc - INFO - welcome\n',
......
......@@ -10,7 +10,7 @@ skip_no_parrots = pytest.mark.skipif(
TORCH_VERSION != 'parrots', reason='test case under parrots environment')
class TestJit(object):
class TestJit:
def test_add_dict(self):
......@@ -255,7 +255,7 @@ class TestJit(object):
def test_instance_method(self):
class T(object):
class T:
def __init__(self, shape):
self._c = torch.rand(shape)
......
......@@ -30,12 +30,14 @@ def test_scandir():
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT']
assert set(mmcv.scandir(folder)) == set(filenames)
assert set(mmcv.scandir(Path(folder))) == set(filenames)
assert set(mmcv.scandir(folder, '.txt')) == set(
[filename for filename in filenames if filename.endswith('.txt')])
assert set(mmcv.scandir(folder, ('.json', '.txt'))) == set([
filename for filename in filenames
if filename.endswith(('.txt', '.json'))
])
assert set(mmcv.scandir(folder, '.txt')) == {
filename
for filename in filenames if filename.endswith('.txt')
}
assert set(mmcv.scandir(folder, ('.json', '.txt'))) == {
filename
for filename in filenames if filename.endswith(('.txt', '.json'))
}
assert set(mmcv.scandir(folder, '.png')) == set()
# path of sep is `\\` in windows but `/` in linux, so osp.join should be
......@@ -46,27 +48,33 @@ def test_scandir():
osp.join('sub', '1.txt'), '.file'
]
# .file starts with '.' and is a file so it will not be scanned
assert set(mmcv.scandir(folder, recursive=True)) == set(
[filename for filename in filenames_recursive if filename != '.file'])
assert set(mmcv.scandir(Path(folder), recursive=True)) == set(
[filename for filename in filenames_recursive if filename != '.file'])
assert set(mmcv.scandir(folder, '.txt', recursive=True)) == set([
filename for filename in filenames_recursive
if filename.endswith('.txt')
])
assert set(mmcv.scandir(folder, recursive=True)) == {
filename
for filename in filenames_recursive if filename != '.file'
}
assert set(mmcv.scandir(Path(folder), recursive=True)) == {
filename
for filename in filenames_recursive if filename != '.file'
}
assert set(mmcv.scandir(folder, '.txt', recursive=True)) == {
filename
for filename in filenames_recursive if filename.endswith('.txt')
}
assert set(
mmcv.scandir(folder, '.TXT', recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
case_sensitive=False)) == {
filename
for filename in filenames_recursive
if filename.endswith(('.txt', '.TXT'))
])
}
assert set(
mmcv.scandir(
folder, ('.TXT', '.JSON'), recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
case_sensitive=False)) == {
filename
for filename in filenames_recursive
if filename.endswith(('.txt', '.json', '.TXT'))
])
}
with pytest.raises(TypeError):
list(mmcv.scandir(123))
with pytest.raises(TypeError):
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import time
from io import StringIO
from unittest.mock import patch
try:
from unittest.mock import patch
except ImportError:
from mock import patch
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import mmcv # isort:skip
import mmcv
def reset_string_io(io):
......
......@@ -70,7 +70,7 @@ def test_assert_dict_contains_subset():
def test_assert_attrs_equal():
class TestExample(object):
class TestExample:
a, b, c = 1, ('wvi', 3), [4.5, 3.14]
def test_func(self):
......@@ -104,7 +104,7 @@ def test_assert_attrs_equal():
if torch is not None:
class TestExample(object):
class TestExample:
a, b = torch.tensor([1]), torch.tensor([4, 5])
# case 5
......
......@@ -41,7 +41,7 @@ def test_parse_version_info():
def _mock_cmd_success(cmd):
return '3b46d33e90c397869ad5103075838fdfc9812aa0'.encode('ascii')
return b'3b46d33e90c397869ad5103075838fdfc9812aa0'
def _mock_cmd_fail(cmd):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment