Unverified Commit 9f04477f authored by lizz's avatar lizz Committed by GitHub
Browse files

Harmless changes (#340)



* harmless changes
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Update conv_module.py

* No bracket for simple class
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent e8fa66aa
from torch import nn as nn
from torch import nn
from .registry import CONV_LAYERS
......
......@@ -89,7 +89,7 @@ class ConvModule(nn.Module):
self.with_activation = act_cfg is not None
# if the conv layer is before a norm layer, bias is unnecessary.
if bias == 'auto':
bias = False if self.with_norm else True
bias = not self.with_norm
self.with_bias = bias
if self.with_norm and self.with_bias:
......
......@@ -140,7 +140,4 @@ def is_norm(layer, exclude=None):
return False
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
if isinstance(layer, all_norm_bases):
return True
else:
return False
return isinstance(layer, all_norm_bases)
......@@ -193,7 +193,7 @@ class HardDiskBackend(BaseStorageBackend):
return value_buf
class FileClient(object):
class FileClient:
"""A general file client to access files in different backend.
The client loads a file or text in a specified backend from its path
......
......@@ -328,8 +328,8 @@ def impad(img, shape, pad_val=0):
if len(shape) < len(img.shape):
shape = shape + (img.shape[-1], )
assert len(shape) == len(img.shape)
for i in range(len(shape)):
assert shape[i] >= img.shape[i]
for s, img_s in zip(shape, img.shape):
assert s >= img_s
pad = np.empty(shape, dtype=img.dtype)
pad[...] = pad_val
pad[:img.shape[0], :img.shape[1], ...] = img
......
......@@ -57,7 +57,7 @@ def get_input_device(input):
raise Exception(f'Unknown type {type(input)}.')
class Scatter(object):
class Scatter:
@staticmethod
def forward(target_gpus, input):
......
......@@ -17,7 +17,7 @@ def assert_tensor_type(func):
return wrapper
class DataContainer(object):
class DataContainer:
"""A container for any type of objects.
Typically tensors will be stacked in the collate function and sliced along
......
......@@ -19,7 +19,4 @@ def is_parallel_module(module):
"""
parallels = (DataParallel, DistributedDataParallel,
MMDistributedDataParallel)
if isinstance(module, parallels):
return True
else:
return False
return isinstance(module, parallels)
......@@ -4,7 +4,7 @@ from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook(object):
class Hook:
def before_run(self, runner):
pass
......
......@@ -4,7 +4,7 @@ from collections import OrderedDict
import numpy as np
class LogBuffer(object):
class LogBuffer:
def __init__(self):
self.val_history = OrderedDict()
......
......@@ -8,7 +8,7 @@ from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor(object):
class DefaultOptimizerConstructor:
"""Default constructor for optimizers.
By default each parameter share the same optimizer settings, and we
......
......@@ -55,7 +55,7 @@ def add_args(parser, cfg, prefix=''):
return parser
class Config(object):
class Config:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
......
......@@ -7,10 +7,7 @@ from .misc import is_str
def is_filepath(x):
if is_str(x) or isinstance(x, Path):
return True
else:
return False
return is_str(x) or isinstance(x, Path)
def fopen(filepath, *args, **kwargs):
......@@ -18,6 +15,7 @@ def fopen(filepath, *args, **kwargs):
return open(filepath, *args, **kwargs)
elif isinstance(filepath, Path):
return filepath.open(*args, **kwargs)
raise ValueError('`filepath` should be a string or a Path')
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
......
......@@ -7,7 +7,7 @@ from shutil import get_terminal_size
from .timer import Timer
class ProgressBar(object):
class ProgressBar:
"""A progress bar which can print the progress"""
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
......
......@@ -5,7 +5,7 @@ from functools import partial
from .misc import is_str
class Registry(object):
class Registry:
"""A registry to map strings to classes.
Args:
......
......@@ -9,7 +9,7 @@ class TimerError(Exception):
super(TimerError, self).__init__(message)
class Timer(object):
class Timer:
"""A flexible Timer class.
:Example:
......
......@@ -11,7 +11,7 @@ from mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
track_progress)
class Cache(object):
class Cache:
def __init__(self, capacity):
self._cache = OrderedDict()
......@@ -39,7 +39,7 @@ class Cache(object):
return val
class VideoReader(object):
class VideoReader:
"""Video class with similar usage to a list object.
This video warpper class provides convenient apis to access frames.
......
......@@ -73,7 +73,7 @@ def resize_video(in_file,
"""
if size is None and ratio is None:
raise ValueError('expected size or ratio must be specified')
elif size is not None and ratio is not None:
if size is not None and ratio is not None:
raise ValueError('size and ratio cannot be specified at the same time')
options = {'log_level': log_level}
if size:
......
......@@ -37,10 +37,10 @@ def color_val(color):
elif isinstance(color, tuple):
assert len(color) == 3
for channel in color:
assert channel >= 0 and channel <= 255
assert 0 <= channel <= 255
return color
elif isinstance(color, int):
assert color >= 0 and color <= 255
assert 0 <= color <= 255
return color, color, color
elif isinstance(color, np.ndarray):
assert color.ndim == 1 and color.size == 3
......
......@@ -13,7 +13,7 @@ sys.modules['petrel_client.client'] = MagicMock()
sys.modules['mc'] = MagicMock()
class MockS3Client(object):
class MockS3Client:
def __init__(self, enable_mc=True):
self.enable_mc = enable_mc
......@@ -24,7 +24,7 @@ class MockS3Client(object):
return content
class MockMemcachedClient(object):
class MockMemcachedClient:
def __init__(self, server_list_cfg, client_cfg):
pass
......@@ -34,7 +34,7 @@ class MockMemcachedClient(object):
buffer.content = f.read()
class TestFileClient(object):
class TestFileClient:
@classmethod
def setup_class(cls):
......@@ -190,7 +190,7 @@ class TestFileClient(object):
# name must be a string
with pytest.raises(TypeError):
class TestClass1(object):
class TestClass1:
pass
FileClient.register_backend(1, TestClass1)
......@@ -202,7 +202,7 @@ class TestFileClient(object):
# module must be a subclass of BaseStorageBackend
with pytest.raises(TypeError):
class TestClass1(object):
class TestClass1:
pass
FileClient.register_backend('TestClass1', TestClass1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment