Unverified Commit b7e8d7d7 authored by Cao Yuhang, committed by GitHub

Use f-string (#245)

* use f-string

* delete python3.5 sup

* minor fix

* fix supported python version

* fix format

* fix yapf

* remove redundant space

* fix typo
parent 1e8a2121
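
The change is mechanical throughout: every `str.format()` call becomes an f-string, which is why the Python 3.5 CI target is dropped below (f-strings require Python 3.6+). A minimal before/after sketch of the pattern (variable names are illustrative):

```python
name, count = 'mmcv', 3

# before: works on Python 3.5
msg = 'package {} has {} handlers'.format(name, count)

# after: equivalent, but requires Python 3.6+
msg = f'package {name} has {count} handlers'
```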
@@ -19,7 +19,6 @@ env:
   - COLUMNS=80
 python:
-  - "3.5"
   - "3.6"
   - "3.7"
   - "3.8"
......
@@ -30,7 +30,7 @@ It provides the following functionalities.
 See the `documentation <http://mmcv.readthedocs.io/en/latest>`_ for more features and usage.
-Note: MMCV requires Python 3.5+.
+Note: MMCV requires Python 3.6+.
 Installation
......
@@ -17,11 +17,10 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
     """
     if not (isinstance(levels, int) and levels > 1):
         raise ValueError(
-            'levels must be a positive integer, but got {}'.format(levels))
+            f'levels must be a positive integer, but got {levels}')
     if min_val >= max_val:
         raise ValueError(
-            'min_val ({}) must be smaller than max_val ({})'.format(
-                min_val, max_val))
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
     arr = np.clip(arr, min_val, max_val) - min_val
     quantized_arr = np.minimum(
@@ -45,11 +44,10 @@ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
     """
     if not (isinstance(levels, int) and levels > 1):
         raise ValueError(
-            'levels must be a positive integer, but got {}'.format(levels))
+            f'levels must be a positive integer, but got {levels}')
     if min_val >= max_val:
         raise ValueError(
-            'min_val ({}) must be smaller than max_val ({})'.format(
-                min_val, max_val))
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
     dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                    min_val) / levels + min_val
......
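
For context, a usage sketch of the two functions touched above, assuming `quantize` and `dequantize` are imported from `mmcv.arraymisc` (values are illustrative):

```python
import numpy as np
from mmcv.arraymisc import quantize, dequantize

arr = np.random.rand(4)
q = quantize(arr, min_val=0.0, max_val=1.0, levels=10)    # integer bins
deq = dequantize(q, min_val=0.0, max_val=1.0, levels=10)  # approximate floats

# swapping the bounds now fails with the new f-string message:
# quantize(arr, 1.0, 0.0, 10)
# ValueError: min_val (1.0) must be smaller than max_val (0.0)
```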
@@ -220,7 +220,7 @@ class ResNet(nn.Module):
                  with_cp=False):
         super(ResNet, self).__init__()
         if depth not in self.arch_settings:
-            raise KeyError('invalid depth {} for resnet'.format(depth))
+            raise KeyError(f'invalid depth {depth} for resnet')
         assert num_stages >= 1 and num_stages <= 4
         block, stage_blocks = self.arch_settings[depth]
         stage_blocks = stage_blocks[:num_stages]
@@ -256,7 +256,7 @@ class ResNet(nn.Module):
                 style=self.style,
                 with_cp=with_cp)
             self.inplanes = planes * block.expansion
-            layer_name = 'layer{}'.format(i + 1)
+            layer_name = f'layer{i + 1}'
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
@@ -309,7 +309,7 @@ class ResNet(nn.Module):
             self.bn1.weight.requires_grad = False
             self.bn1.bias.requires_grad = False
             for i in range(1, self.frozen_stages + 1):
-                mod = getattr(self, 'layer{}'.format(i))
+                mod = getattr(self, f'layer{i}')
                 mod.eval()
                 for param in mod.parameters():
                     param.requires_grad = False
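
The `f'layer{i + 1}'` / `getattr(self, f'layer{i}')` pair above is the usual dynamic-attribute pattern for stage-wise networks. A self-contained toy version (`TinyNet` is hypothetical, not part of mmcv):

```python
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy module mirroring ResNet's layer-naming scheme."""

    def __init__(self, num_layers=3):
        super().__init__()
        self.layer_names = []
        for i in range(num_layers):
            name = f'layer{i + 1}'          # same f-string pattern as the diff
            self.add_module(name, nn.Linear(8, 8))
            self.layer_names.append(name)

    def freeze(self, up_to):
        # mirrors ResNet's frozen_stages loop: modules looked up by name
        for i in range(1, up_to + 1):
            mod = getattr(self, f'layer{i}')
            mod.eval()
            for param in mod.parameters():
                param.requires_grad = False
```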
@@ -73,7 +73,7 @@ class VGG(nn.Module):
                  with_last_pool=True):
         super(VGG, self).__init__()
         if depth not in self.arch_settings:
-            raise KeyError('invalid depth {} for vgg'.format(depth))
+            raise KeyError(f'invalid depth {depth} for vgg')
         assert num_stages >= 1 and num_stages <= 5
         stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
......
@@ -169,20 +169,19 @@ class FileClient(object):
     def __init__(self, backend='disk', **kwargs):
         if backend not in self._backends:
             raise ValueError(
-                'Backend {} is not supported. Currently supported ones are {}'.
-                format(backend, list(self._backends.keys())))
+                f'Backend {backend} is not supported. Currently supported ones'
+                f' are {list(self._backends.keys())}')
         self.backend = backend
         self.client = self._backends[backend](**kwargs)

     @classmethod
     def register_backend(cls, name, backend):
         if not inspect.isclass(backend):
-            raise TypeError('backend should be a class but got {}'.format(
-                type(backend)))
+            raise TypeError(
+                f'backend should be a class but got {type(backend)}')
         if not issubclass(backend, BaseStorageBackend):
             raise TypeError(
-                'backend {} is not a subclass of BaseStorageBackend'.format(
-                    backend))
+                f'backend {backend} is not a subclass of BaseStorageBackend')
         cls._backends[name] = backend
......
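
To see both rewritten error paths in action, a registration sketch; `DummyBackend` is hypothetical and exists only for illustration, and the imports assume both names are exported from `mmcv.fileio`:

```python
from mmcv.fileio import BaseStorageBackend, FileClient

class DummyBackend(BaseStorageBackend):
    """Hypothetical backend used only to exercise register_backend."""

    def get(self, filepath):
        return b'raw bytes'

    def get_text(self, filepath):
        return 'text'

FileClient.register_backend('dummy', DummyBackend)
client = FileClient(backend='dummy')

# an unknown name now raises the f-string message above:
# FileClient(backend='nope')
# ValueError: Backend nope is not supported. Currently supported ones are [...]
```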
@@ -34,7 +34,7 @@ def load(file, file_format=None, **kwargs):
     if file_format is None and is_str(file):
         file_format = file.split('.')[-1]
     if file_format not in file_handlers:
-        raise TypeError('Unsupported format: {}'.format(file_format))
+        raise TypeError(f'Unsupported format: {file_format}')
     handler = file_handlers[file_format]
     if is_str(file):
@@ -71,7 +71,7 @@ def dump(obj, file=None, file_format=None, **kwargs):
         raise ValueError(
             'file_format must be specified since file is None')
     if file_format not in file_handlers:
-        raise TypeError('Unsupported format: {}'.format(file_format))
+        raise TypeError(f'Unsupported format: {file_format}')
     handler = file_handlers[file_format]
     if file is None:
@@ -94,8 +94,7 @@ def _register_handler(handler, file_formats):
     """
     if not isinstance(handler, BaseFileHandler):
         raise TypeError(
-            'handler must be a child of BaseFileHandler, not {}'.format(
-                type(handler)))
+            f'handler must be a child of BaseFileHandler, not {type(handler)}')
     if isinstance(file_formats, str):
         file_formats = [file_formats]
     if not is_list_of(file_formats, str):
......
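
`load` and `dump` infer the handler from the file suffix, so the f-string error above is exactly what users see for an unknown extension. A quick sketch (paths are illustrative):

```python
import mmcv

data = {'a': 1, 'b': [2, 3]}
mmcv.dump(data, '/tmp/demo.json')          # format inferred from '.json'
assert mmcv.load('/tmp/demo.json') == data

# mmcv.load('/tmp/demo.xyz')
# TypeError: Unsupported format: xyz
```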
@@ -13,7 +13,7 @@ def imconvert(img, src, dst):
     Returns:
         ndarray: The converted image.
     """
-    code = getattr(cv2, 'COLOR_{}2{}'.format(src.upper(), dst.upper()))
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
     out_img = cv2.cvtColor(img, code)
     return out_img
@@ -82,20 +82,21 @@ def gray2rgb(img):
 def convert_color_factory(src, dst):

-    code = getattr(cv2, 'COLOR_{}2{}'.format(src.upper(), dst.upper()))
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')

     def convert_color(img):
         out_img = cv2.cvtColor(img, code)
         return out_img

-    convert_color.__doc__ = """Convert a {0} image to {1} image.
+    convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+    image.

     Args:
         img (ndarray or str): The input image.

     Returns:
-        ndarray: The converted {1} image.
-    """.format(src.upper(), dst.upper())
+        ndarray: The converted {dst.upper()} image.
+    """

     return convert_color
......
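
Note why the `f"""..."""` docstring works here: the factory runs once per color pair, so the f-string is evaluated at creation time with the closed-over `src` and `dst`, not lazily. A sketch, assuming `convert_color_factory` is importable from `mmcv.image.colorspace` (this is how mmcv builds its own `bgr2rgb`, `rgb2hsv`, etc.):

```python
from mmcv.image.colorspace import convert_color_factory

bgr2hsv = convert_color_factory('bgr', 'hsv')
print(bgr2hsv.__doc__)   # 'Convert a BGR image to HSV image. ...'
```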
@@ -91,8 +91,7 @@ def rescale_size(old_size, scale, return_scale=False):
     w, h = old_size
     if isinstance(scale, (float, int)):
         if scale <= 0:
-            raise ValueError(
-                'Invalid scale {}, must be positive.'.format(scale))
+            raise ValueError(f'Invalid scale {scale}, must be positive.')
         scale_factor = scale
     elif isinstance(scale, tuple):
         max_long_edge = max(scale)
@@ -101,8 +100,7 @@ def rescale_size(old_size, scale, return_scale=False):
                                max_short_edge / min(h, w))
     else:
         raise TypeError(
-            'Scale must be a number or tuple of int, but got {}'.format(
-                type(scale)))
+            f'Scale must be a number or tuple of int, but got {type(scale)}')
     new_size = _scale_size((w, h), scale_factor)
......
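
A usage sketch of `rescale_size`, assuming it is exported at the package top level as in mmcv:

```python
import mmcv

# numeric scale: multiply both edges
mmcv.rescale_size((1920, 1080), 0.5)         # -> (960, 540)

# tuple scale: fit inside (long_edge, short_edge), keeping aspect ratio
mmcv.rescale_size((1920, 1080), (800, 600))

# mmcv.rescale_size((1920, 1080), -1)
# ValueError: Invalid scale -1, must be positive.
```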
@@ -80,7 +80,7 @@ def imread(img_or_path, flag='color', channel_order='bgr'):
         return img_or_path
     elif is_str(img_or_path):
         check_file_exist(img_or_path,
-                         'img file does not exist: {}'.format(img_or_path))
+                         f'img file does not exist: {img_or_path}')
         if imread_backend == 'turbojpeg':
             with open(img_or_path, 'rb') as in_file:
                 img = jpeg.decode(in_file.read(),
......
@@ -24,7 +24,7 @@ def scatter(input, devices, streams=None):
             output = output.cuda(devices[0], non_blocking=True)
         return output
     else:
-        raise Exception('Unknown type {}.'.format(type(input)))
+        raise Exception(f'Unknown type {type(input)}.')

 def synchronize_stream(output, devices, streams):
@@ -41,7 +41,7 @@ def synchronize_stream(output, devices, streams):
         main_stream.wait_stream(streams[0])
         output.record_stream(main_stream)
     else:
-        raise Exception('Unknown type {}.'.format(type(output)))
+        raise Exception(f'Unknown type {type(output)}.')

 def get_input_device(input):
@@ -54,7 +54,7 @@ def get_input_device(input):
     elif isinstance(input, torch.Tensor):
         return input.get_device() if input.is_cuda else -1
     else:
-        raise Exception('Unknown type {}.'.format(type(input)))
+        raise Exception(f'Unknown type {type(input)}.')

 class Scatter(object):
......
@@ -21,7 +21,7 @@ def collate(batch, samples_per_gpu=1):
     """
     if not isinstance(batch, collections.Sequence):
-        raise TypeError('{} is not supported.'.format(batch.dtype))
+        raise TypeError(f'{batch.dtype} is not supported.')
     if isinstance(batch[0], DataContainer):
         assert len(batch) % samples_per_gpu == 0
......
@@ -9,8 +9,9 @@ def assert_tensor_type(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if not isinstance(args[0].data, torch.Tensor):
-            raise AttributeError('{} has no attribute {} for type {}'.format(
-                args[0].__class__.__name__, func.__name__, args[0].datatype))
+            raise AttributeError(
+                f'{args[0].__class__.__name__} has no attribute '
+                f'{func.__name__} for type {args[0].datatype}')
         return func(*args, **kwargs)

     return wrapper
@@ -47,7 +48,7 @@ class DataContainer(object):
         self._pad_dims = pad_dims

     def __repr__(self):
-        return '{}({})'.format(self.__class__.__name__, repr(self.data))
+        return f'{self.__class__.__name__}({repr(self.data)})'

     @property
     def data(self):
......
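
The new `__repr__` is easy to check interactively; a minimal sketch assuming `DataContainer` is exported from `mmcv.parallel`:

```python
import torch
from mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 2), stack=True)
print(repr(dc))   # DataContainer(tensor(...)): class name wrapping the data repr
```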
@@ -86,11 +86,11 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     ]
     if unexpected_keys:
-        err_msg.append('unexpected key in source state_dict: {}\n'.format(
-            ', '.join(unexpected_keys)))
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
     if missing_keys:
-        err_msg.append('missing keys in source state_dict: {}\n'.format(
-            ', '.join(missing_keys)))
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

     rank, _ = get_dist_info()
     if len(err_msg) > 0 and rank == 0:
@@ -124,7 +124,7 @@ def get_torchvision_models():
     for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
         if ispkg:
             continue
-        _zoo = import_module('torchvision.models.{}'.format(name))
+        _zoo = import_module(f'torchvision.models.{name}')
         if hasattr(_zoo, 'model_urls'):
             _urls = getattr(_zoo, 'model_urls')
             model_urls.update(_urls)
@@ -160,7 +160,7 @@ def _load_checkpoint(filename, map_location=None):
         checkpoint = load_url_dist(filename)
     else:
         if not osp.isfile(filename):
-            raise IOError('{} is not a checkpoint file'.format(filename))
+            raise IOError(f'{filename} is not a checkpoint file')
         checkpoint = torch.load(filename, map_location=map_location)
     return checkpoint
@@ -191,7 +191,7 @@ def load_checkpoint(model,
         state_dict = checkpoint['state_dict']
     else:
         raise RuntimeError(
-            'No state_dict found in checkpoint file {}'.format(filename))
+            f'No state_dict found in checkpoint file {filename}')
     # strip prefix of state_dict
     if list(state_dict.keys())[0].startswith('module.'):
         state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
@@ -233,8 +233,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
     if meta is None:
         meta = {}
     elif not isinstance(meta, dict):
-        raise TypeError('meta must be a dict or None, but got {}'.format(
-            type(meta)))
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')

     meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
     mmcv.mkdir_or_exist(osp.dirname(filename))
......
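
A round-trip sketch over the rewritten checkpoint helpers (the path and the torchvision model are illustrative):

```python
import torchvision
from mmcv.runner import load_checkpoint, save_checkpoint

model = torchvision.models.resnet18()
save_checkpoint(model, '/tmp/demo.pth')                 # meta defaults to {}
load_checkpoint(model, '/tmp/demo.pth', map_location='cpu')

# a bogus path now fails with the f-string message:
# load_checkpoint(model, '/tmp/missing.pth')
# IOError: /tmp/missing.pth is not a checkpoint file
```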
@@ -18,7 +18,7 @@ def init_dist(launcher, backend='nccl', **kwargs):
     elif launcher == 'slurm':
         _init_dist_slurm(backend, **kwargs)
     else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
+        raise ValueError(f'Invalid launcher type: {launcher}')

 def _init_dist_pytorch(backend, **kwargs):
@@ -40,7 +40,7 @@ def _init_dist_slurm(backend, port=29500, **kwargs):
     num_gpus = torch.cuda.device_count()
     torch.cuda.set_device(proc_id % num_gpus)
     addr = subprocess.getoutput(
-        'scontrol show hostname {} | head -n1'.format(node_list))
+        f'scontrol show hostname {node_list} | head -n1')
     os.environ['MASTER_PORT'] = str(port)
     os.environ['MASTER_ADDR'] = addr
     os.environ['WORLD_SIZE'] = str(ntasks)
......
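
The rewritten line builds a shell command for Slurm; standalone, the f-string expands like this (the node list value is made up):

```python
node_list = 'gpu-node[01-04]'   # hypothetical SLURM_NODELIST value
cmd = f'scontrol show hostname {node_list} | head -n1'
# 'scontrol show hostname gpu-node[01-04] | head -n1'
# on a Slurm cluster, subprocess.getoutput(cmd) yields the first hostname,
# which becomes MASTER_ADDR
```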
@@ -67,7 +67,7 @@ class MlflowLoggerHook(LoggerHook):
         for var, val in runner.log_buffer.output.items():
             if var in ['time', 'data_time']:
                 continue
-            tag = '{}/{}'.format(var, runner.mode)
+            tag = f'{var}/{runner.mode}'
             if isinstance(val, numbers.Number):
                 metrics[tag] = val
         metrics['learning_rate'] = runner.current_lr()[0]
......
@@ -45,7 +45,7 @@ class TensorboardLoggerHook(LoggerHook):
         for var in runner.log_buffer.output:
             if var in ['time', 'data_time']:
                 continue
-            tag = '{}/{}'.format(var, runner.mode)
+            tag = f'{var}/{runner.mode}'
             record = runner.log_buffer.output[var]
             if isinstance(record, str):
                 self.writer.add_text(tag, record, runner.iter)
......
@@ -22,7 +22,7 @@ class TextLoggerHook(LoggerHook):
         super(TextLoggerHook, self).before_run(runner)
         self.start_iter = runner.iter
         self.json_log_path = osp.join(runner.work_dir,
-                                      '{}.log.json'.format(runner.timestamp))
+                                      f'{runner.timestamp}.log.json')
         if runner.meta is not None:
             self._dump_log(runner.meta, runner)
@@ -37,25 +37,24 @@ class TextLoggerHook(LoggerHook):
     def _log_info(self, log_dict, runner):
         if runner.mode == 'train':
-            log_str = 'Epoch [{}][{}/{}]\tlr: {:.5f}, '.format(
-                log_dict['epoch'], log_dict['iter'], len(runner.data_loader),
-                log_dict['lr'])
+            log_str = f'Epoch [{log_dict["epoch"]}]' \
+                      f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' \
+                      f'lr: {log_dict["lr"]:.5f}, '
             if 'time' in log_dict.keys():
                 self.time_sec_tot += (log_dict['time'] * self.interval)
                 time_sec_avg = self.time_sec_tot / (
                     runner.iter - self.start_iter + 1)
                 eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
                 eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
-                log_str += 'eta: {}, '.format(eta_str)
-                log_str += ('time: {:.3f}, data_time: {:.3f}, '.format(
-                    log_dict['time'], log_dict['data_time']))
+                log_str += f'eta: {eta_str}, '
+                log_str += f'time: {log_dict["time"]:.3f}, ' \
+                           f'data_time: {log_dict["data_time"]:.3f}, '
                 # statistic memory
                 if torch.cuda.is_available():
-                    log_str += 'memory: {}, '.format(log_dict['memory'])
+                    log_str += f'memory: {log_dict["memory"]}, '
         else:
-            log_str = 'Epoch({}) [{}][{}]\t'.format(log_dict['mode'],
-                                                    log_dict['epoch'] - 1,
-                                                    log_dict['iter'])
+            log_str = f'Epoch({log_dict["mode"]}) ' \
+                      f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
         log_items = []
         for name, val in log_dict.items():
             # TODO: resolve this hack
@@ -66,8 +65,8 @@ class TextLoggerHook(LoggerHook):
             ]:
                 continue
             if isinstance(val, float):
-                val = '{:.4f}'.format(val)
-            log_items.append('{}: {}'.format(name, val))
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
         log_str += ', '.join(log_items)
         runner.logger.info(log_str)
......
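
One subtlety in the multi-line f-strings above: before Python 3.12, an f-string delimited by single quotes cannot reuse single quotes inside its `{...}` fields, hence the `log_dict["epoch"]` double-quoting. A standalone rendering of the train-mode prefix (values are made up):

```python
log_dict = {'epoch': 3, 'iter': 50, 'lr': 0.001}
log_str = f'Epoch [{log_dict["epoch"]}]' \
          f'[{log_dict["iter"]}/100]\t' \
          f'lr: {log_dict["lr"]:.5f}, '
# 'Epoch [3][50/100]\tlr: 0.00100, '
```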
@@ -42,7 +42,7 @@ class WandbLoggerHook(LoggerHook):
         for var, val in runner.log_buffer.output.items():
             if var in ['time', 'data_time']:
                 continue
-            tag = '{}/{}'.format(var, runner.mode)
+            tag = f'{var}/{runner.mode}'
             if isinstance(val, numbers.Number):
                 metrics[tag] = val
         metrics['learning_rate'] = runner.current_lr()[0]
......
@@ -32,8 +32,8 @@ class LrUpdaterHook(Hook):
         if warmup is not None:
             if warmup not in ['constant', 'linear', 'exp']:
                 raise ValueError(
-                    '"{}" is not a supported type for warming up, valid types'
-                    ' are "constant" and "linear"'.format(warmup))
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant" and "linear"')
         if warmup is not None:
             assert warmup_iters > 0, \
                 '"warmup_iters" must be a positive integer'
@@ -254,7 +254,7 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
                 if len(target_ratio) == 1 else target_ratio
         else:
             raise ValueError('target_ratio should be either float '
-                             'or tuple, got {}'.format(type(target_ratio)))
+                             f'or tuple, got {type(target_ratio)}')

         assert len(target_ratio) == 2, \
             '"target_ratio" must be list or tuple of two floats'
......