Commit b7536f78 authored by limm's avatar limm
Browse files

Add another part of the mmgeneration code

parent 57e0e891
Pipeline #2777 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizers
__all__ = ['build_optimizers']
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If ``cfgs`` contains several dicts for optimizers, then a dict for each
    constructed optimizer will be returned.
    If ``cfgs`` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    # unwrap DataParallel/DistributedDataParallel so sub-modules can be
    # looked up by attribute name
    if hasattr(model, 'module'):
        model = model.module
    # `cfgs` holds one optimizer per sub-module iff every value is a dict
    # (original code looped over all items just to set a flag; `all()` is
    # equivalent and short-circuits)
    if all(isinstance(cfg, dict) for cfg in cfgs.values()):
        optimizers = {}
        for key, cfg in cfgs.items():
            # copy so the caller's config is not mutated downstream
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers
    return build_optimizer(model, cfgs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
METRICS = Registry('metric')
def build(cfg, registry, default_args=None):
    """Build a module (or a list of modules) from config.

    Args:
        cfg (dict, list[dict]): The config of modules; either a single dict
            or a list of config dicts.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module (or a list of built modules when
            ``cfg`` is a list).
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    # a list of configs yields a list of built modules, one per entry
    return [build_from_cfg(item, registry, default_args) for item in cfg]
def build_metric(cfg):
    """Build a metric calculator.

    Args:
        cfg (dict | list[dict]): Config of the metric(s) registered in
            ``METRICS``.

    Returns:
        Object: The built metric calculator (a list when ``cfg`` is a list).
    """
    return build(cfg, METRICS)
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_iterbased_runner import DynamicIterBasedRunner
__all__ = ['DynamicIterBasedRunner']
# Copyright (c) OpenMMLab. All rights reserved.
try:
from apex import amp
except ImportError:
amp = None
def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
    """Initialize apex.amp for mixed-precision training.

    Args:
        models (nn.Module | list[nn.Module]): Modules to be wrapped with
            apex.amp.
        optimizers (dict): Dict containing the ``'generator'`` and
            ``'discriminator'`` optimizers to be wrapped.
        init_args (dict | None, optional): Config for amp initialization.
            Defaults to None.
        mode (str, optional): The mode used to initialize apex.amp.
            Different modes lead to different wrapping modes for models and
            optimizers. Currently only ``'gan'`` is supported.
            Defaults to 'gan'.

    Returns:
        Module, dict: Wrapped modules and optimizers.

    Raises:
        NotImplementedError: If ``mode`` is not supported.
    """
    init_args = init_args or dict()

    if mode == 'gan':
        # apex.amp expects a flat list of optimizers; keep a fixed order so
        # the wrapped optimizers can be mapped back to their dict keys
        _optimizers = [optimizers['generator'], optimizers['discriminator']]
        models, _optimizers = amp.initialize(models, _optimizers, **init_args)
        optimizers['generator'] = _optimizers[0]
        optimizers['discriminator'] = _optimizers[1]
        return models, optimizers
    else:
        raise NotImplementedError(
            f'Cannot initialize apex.amp with mode {mode}')
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory
import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer
def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will
    be saved in the checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``'pavi://'`` prefix saves to
            pavi modelcloud instead of the local filesystem.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer(s) to be
            saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save apex.amp
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # unwrap (D)DP wrappers before collecting weights
    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # record class names in the meta info
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }

    # store optimizer state dict(s); either a single optimizer or one per
    # named sub-model
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }

    # store loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()

    # store state_dict from apex.amp
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()

    if filename.startswith('pavi://'):
        # upload to pavi modelcloud via a temporary local file
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # flush the buffer immediately so the file hits disk right away
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings
from functools import partial
import mmcv
import torch
import torch.distributed as dist
from mmcv.parallel import collate, is_module_wrapper
from mmcv.runner import HOOKS, RUNNERS, IterBasedRunner, get_host_info
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .checkpoint import save_checkpoint
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import GradScaler
except ImportError:
pass
class IterLoader:
    """Iteration based dataloader.

    This wrapper makes a dataloader match the iter-based training
    procedure: ``next()`` can be called indefinitely and the epoch counter
    is advanced automatically each time the wrapped loader is exhausted.

    Args:
        dataloader (object): Dataloader in PyTorch.
        runner (object): ``mmcv.Runner``
    """

    def __init__(self, dataloader, runner):
        self._dataloader = dataloader
        self.runner = runner
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        """The number of the current epoch.

        Returns:
            int: Epoch number.
        """
        return self._epoch

    def update_dataloader(self, curr_scale):
        """Update dataloader.

        Update the dataloader according to the `curr_scale`. This
        functionality is very helpful in training progressive growing GANs
        in which the dataloader should be updated according to the scale of
        the models in training.

        Args:
            curr_scale (int): The scale in current stage.
        """
        # update dataset, sampler, and samples per gpu in dataloader
        if hasattr(self._dataloader.dataset, 'update_annotations'):
            update_flag = self._dataloader.dataset.update_annotations(
                curr_scale)
        else:
            update_flag = False

        if update_flag:
            # the sampler should be updated with the modified dataset
            assert hasattr(self._dataloader.sampler, 'update_sampler')
            samples_per_gpu = None if not hasattr(
                self._dataloader.dataset, 'samples_per_gpu'
            ) else self._dataloader.dataset.samples_per_gpu
            self._dataloader.sampler.update_sampler(self._dataloader.dataset,
                                                    samples_per_gpu)
            # update samples per gpu
            if samples_per_gpu is not None:
                if dist.is_initialized():
                    # rebuild the dataloader: mutating `batch_size` on a
                    # live DataLoader is not supported by PyTorch
                    self._dataloader = DataLoader(
                        self._dataloader.dataset,
                        batch_size=samples_per_gpu,
                        sampler=self._dataloader.sampler,
                        num_workers=self._dataloader.num_workers,
                        collate_fn=partial(
                            collate, samples_per_gpu=samples_per_gpu),
                        shuffle=False,
                        worker_init_fn=self._dataloader.worker_init_fn)
                    self.iter_loader = iter(self._dataloader)
                else:
                    raise NotImplementedError(
                        'Currently, we only support dynamic batch size in'
                        ' ddp, because the number of gpus in DataParallel '
                        'cannot be obtained easily.')

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # wrapped loader exhausted: one epoch finished, restart it
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)
@RUNNERS.register_module()
class DynamicIterBasedRunner(IterBasedRunner):
    """Dynamic Iterbased Runner.

    In this Dynamic Iterbased Runner, we will pass the ``reducer`` to the
    ``train_step`` so that the models can be trained with dynamic
    architecture.
    More details and clarification can be found in this [tutorial](docs/en/tutorials/ddp_train_gans.md). # noqa

    Args:
        is_dynamic_ddp (bool, optional): Whether to adopt the dynamic ddp.
            Defaults to False.
        pass_training_status (bool, optional): Whether to pass the training
            status. Defaults to False.
        fp16_loss_scaler (dict | None, optional): Config for fp16 GradScaler
            from ``torch.cuda.amp``. Defaults to None.
        use_apex_amp (bool, optional): Whether to use apex.amp to start mixed
            precision training. Defaults to False.
    """

    def __init__(self,
                 *args,
                 is_dynamic_ddp=False,
                 pass_training_status=False,
                 fp16_loss_scaler=None,
                 use_apex_amp=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # unwrap (D)DP wrappers so the raw model's attributes can be read
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model

        self.is_dynamic_ddp = is_dynamic_ddp
        self.pass_training_status = pass_training_status

        # add a flag for checking if `self.optimizer` comes from `_model`
        self.optimizer_from_model = False
        # add support for optimizer is None.
        # sanity check for whether `_model` contains self-defined optimizer
        if hasattr(_model, 'optimizer'):
            assert self.optimizer is None, (
                'Runner and model cannot contain optimizer at the same time.')
            self.optimizer_from_model = True
            self.optimizer = _model.optimizer

        # add fp16 grad scaler, using pytorch official GradScaler
        self.with_fp16_grad_scaler = False
        if fp16_loss_scaler is not None:
            self.loss_scaler = GradScaler(**fp16_loss_scaler)
            self.with_fp16_grad_scaler = True
            mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen')

        # flag to use amp in apex (NVIDIA)
        self.use_apex_amp = use_apex_amp

    def call_hook(self, fn_name):
        """Call all hooks.

        Hooks that do not implement ``fn_name`` are silently skipped, so
        custom hook points (e.g. ``before_fetch_train_data``) can be used
        without every registered hook providing them.

        Args:
            fn_name (str): The function name in each hook to be called, such
                as "before_train_epoch".
        """
        for hook in self._hooks:
            if hasattr(hook, fn_name):
                getattr(hook, fn_name)(self)

    def train(self, data_loader, **kwargs):
        """Run one training iteration.

        Args:
            data_loader (:obj:`IterLoader`): Wrapped dataloader to fetch the
                next data batch from.
            kwargs (dict): Extra keyword arguments forwarded to
                ``model.train_step``.
        """
        if is_module_wrapper(self.model):
            _model = self.model.module
        else:
            _model = self.model
        self.model.train()
        self.mode = 'train'
        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        self.call_hook('before_fetch_train_data')
        data_batch = next(self.data_loader)
        self.call_hook('before_train_iter')

        # prepare input args for train_step
        # running status
        if self.pass_training_status:
            running_status = dict(iteration=self.iter, epoch=self.epoch)
            kwargs['running_status'] = running_status
        # ddp reducer for tracking dynamic computational graph
        if self.is_dynamic_ddp:
            kwargs.update(dict(ddp_reducer=self.model.reducer))
        # native torch.cuda.amp loss scaler
        if self.with_fp16_grad_scaler:
            kwargs.update(dict(loss_scaler=self.loss_scaler))
        # apex.amp mixed-precision training
        if self.use_apex_amp:
            kwargs.update(dict(use_apex_amp=True))

        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

        # the loss scaler should be updated after ``train_step``
        if self.with_fp16_grad_scaler:
            self.loss_scaler.update()

        # further check for the cases where the optimizer is built in
        # `train_step`.
        if self.optimizer is None:
            if hasattr(_model, 'optimizer'):
                self.optimizer_from_model = True
                self.optimizer = _model.optimizer
        # check if self.optimizer from model and track it
        if self.optimizer_from_model:
            self.optimizer = _model.optimizer

        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g, [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        # wrap each dataloader so epochs are tracked transparently
        iter_loaders = [IterLoader(x, self) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               resume_loss_scaler=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the
                optimizer(s) if the checkpoint file includes optimizer(s).
                Default to True.
            resume_loss_scaler (bool, optional): Whether to resume the loss
                scaler (GradScaler) from ``torch.cuda.amp``. Defaults to
                True.
            map_location (str, optional): Same as :func:`torch.load`.
                Default to 'default'.
        """
        if map_location == 'default':
            # 'default' maps storages to the current CUDA device
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'loss_scaler' in checkpoint and resume_loss_scaler:
            self.loss_scaler.load_state_dict(checkpoint['loss_scaler'])

        if self.use_apex_amp:
            from apex import amp
            amp.load_state_dict(checkpoint['amp'])

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        _loss_scaler = self.loss_scaler if self.with_fp16_grad_scaler else None
        save_checkpoint(
            self.model,
            filepath,
            optimizer=optimizer,
            loss_scaler=_loss_scaler,
            save_apex_amp=self.use_apex_amp,
            meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_lr_hook(self, lr_config):
        """Register an LR-updater hook built from ``lr_config``.

        Args:
            lr_config (dict | Hook | None): LR updater config containing a
                'policy' key, an already-built hook, or None to skip
                registration.
        """
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `
            # CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from collections import abc
from inspect import getfullargspec
import numpy as np
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import autocast
except ImportError:
pass
def nan_to_num(x, nan=0.0, posinf=None, neginf=None, *, out=None):
    r"""Replaces :literal:`NaN`, positive infinity, and negative infinity
    values in :attr:`input` with the values specified by :attr:`nan`,
    :attr:`posinf`, and :attr:`neginf`, respectively. By default,
    :literal:`NaN`s are replaced with zero, positive infinity is replaced with
    the greatest finite value representable by :attr:`input`'s dtype, and
    negative infinity is replaced with the least finite value representable by
    :attr:`input`'s dtype.

    .. note::
        This function is provided in ``PyTorch>=1.8.0``. Here is a
        reimplementation to avoid attribute error in lower PyTorch version.

    Args:
        x (Tensor): Input tensor.
        nan (Number, optional): the value to replace :literal:`NaN`\s with.
            Default is zero. (The fallback only supports ``nan == 0``.)
        posinf (Number, optional): if a Number, the value to replace positive
            infinity values with. If None, positive infinity values are
            replaced with the greatest finite value representable by
            :attr:`input`'s dtype. Default is None.
        neginf (Number, optional): if a Number, the value to replace negative
            infinity values with. If None, negative infinity values are
            replaced with the lowest finite value representable by
            :attr:`input`'s dtype. Default is None.

    Returns:
        Tensor: Output tensor.
    """
    try:
        return torch.nan_to_num(
            x, nan=nan, posinf=posinf, neginf=neginf, out=out)
    except AttributeError:
        # fallback for PyTorch < 1.8.0 where torch.nan_to_num is missing
        if not isinstance(x, torch.Tensor):
            raise TypeError(
                f'argument input (position 1) must be Tensor, not {type(x)}')
        if posinf is None:
            posinf = torch.finfo(x.dtype).max
        if neginf is None:
            neginf = torch.finfo(x.dtype).min
        assert nan == 0
        # a better choice is to use nansum, but this function is not supported
        # in PyTorch 1.5
        # work on a copy: the original code wrote into `x` in place, which
        # mutated the caller's tensor -- torch.nan_to_num never modifies its
        # input (unless `out=x` is explicitly requested)
        res = x.clone()
        res[torch.isnan(res)] = 0.
        return torch.clamp(res, min=neginf, max=posinf, out=out)
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in ``inputs`` from src_type to dst_type.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same structure as ``inputs``, but with all contained Tensors
        cast to ``dst_type``.
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    # non-container leaves pass through unchanged; str must be caught here
    # because it is also an Iterable
    if isinstance(inputs, (nn.Module, str, np.ndarray)):
        return inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            key: cast_tensor_type(value, src_type, dst_type)
            for key, value in inputs.items()
        })
    if isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(element, src_type, dst_type)
            for element in inputs)
    return inputs
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend; lower versions raise ``RuntimeError``.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            ``None`` indicates that no arguments are cast (only autocast is
            enabled for the call).
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:
        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if
            # not, just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # the class itself may force fp32 outputs via `out_fp32`
            if hasattr(args[0], 'out_fp32') and args[0].out_fp32:
                _out_fp32 = True
            else:
                _out_fp32 = False
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # names of the arguments to be cast; `None` casts nothing
            args_to_cast = [] if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float,
                                             torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method.
            # NOTE: compare versions with `digit_version`; a plain string
            # comparison is wrong for versions such as '1.10.0' < '1.6.0'.
            from mmcv.utils import digit_version
            if (TORCH_VERSION != 'parrots'
                    and digit_version(TORCH_VERSION) >= digit_version(
                        '1.6.0')):
                output = autocast(enabled=True)(old_func)(*new_args,
                                                          **new_kwargs)
            else:
                raise RuntimeError('Please use PyTorch >= 1.6.0')
            # cast the results back to fp32 if necessary
            if out_fp32 or _out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_updater import LinearLrUpdaterHook
__all__ = ['LinearLrUpdaterHook']
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS, LrUpdaterHook
@HOOKS.register_module()
class LinearLrUpdaterHook(LrUpdaterHook):
    """Linear learning rate scheduler for image generation.

    In the beginning, the learning rate is 'base_lr' defined in mmcv.
    We give a target learning rate 'target_lr' and a start point 'start'
    (iteration / epoch). Before 'start', we fix learning rate as 'base_lr';
    After 'start', we linearly update learning rate to 'target_lr'.

    Args:
        target_lr (float): The target learning rate. Default: 0.
        start (int): The start point (iteration / epoch, specified by args
            'by_epoch' in its parent class in mmcv) to update learning rate.
            Default: 0.
        interval (int): The interval to update the learning rate. Default: 1.
    """

    def __init__(self, target_lr=0, start=0, interval=1, **kwargs):
        super().__init__(**kwargs)
        self.target_lr = target_lr
        self.start = start
        self.interval = interval

    def get_lr(self, runner, base_lr):
        """Calculates the learning rate.

        Args:
            runner (object): The passed runner.
            base_lr (float): Base learning rate.

        Returns:
            float: Current learning rate.
        """
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        assert max_progress >= self.start
        if max_progress == self.start:
            return base_lr

        # Before 'start', fix lr; After 'start', linearly update lr.
        # Guard the denominator: when `interval` is larger than the whole
        # decay range, `(max_progress - start) // interval` is 0 and the
        # original code raised ZeroDivisionError.
        num_steps = max((max_progress - self.start) // self.interval, 1)
        factor = (max(0, progress - self.start) // self.interval) / num_steps
        return base_lr + (self.target_lr - base_lr) * factor
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .paired_image_dataset import PairedImageDataset
from .pipelines import (Collect, Compose, Flip, ImageToTensor,
LoadImageFromFile, Normalize, Resize, ToTensor)
from .quick_test_dataset import QuickTestImageDataset
from .samplers import DistributedSampler
from .singan_dataset import SinGANDataset
from .unconditional_image_dataset import UnconditionalImageDataset
from .unpaired_image_dataset import UnpairedImageDataset
__all__ = [
'build_dataloader', 'build_dataset', 'LoadImageFromFile',
'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor',
'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize',
'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset',
'UnpairedImageDataset', 'QuickTestImageDataset'
]
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import random
import warnings
from copy import deepcopy
from functools import partial
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
from .samplers import DistributedSampler
if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    # Dataloader workers can exhaust the default per-process limit of open
    # file descriptors; raise the soft limit (capped at the hard limit).
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

# registries for datasets and data-processing pipeline transforms
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
def build_dataset(cfg, default_args=None):
    """Build dataset.

    Args:
        cfg (dict): Config for the dataset.
        default_args (dict | None, optional): Default arguments.
            Defaults to None.

    Returns:
        Object: Dataset for sampling data batch.
    """
    from .dataset_wrappers import RepeatDataset

    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError('Currently, we do NOT support ConcatDataset')

    cfg_type = cfg['type']
    if cfg_type == 'RepeatDataset':
        # build the wrapped dataset recursively, then repeat it
        return RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    if cfg_type.startswith('mmcls.'):
        # add support for using datasets from `MMClassification`
        try:
            from mmcls.datasets import build_dataset as build_dataset_mmcls
        except ImportError:
            raise ImportError(
                f'Please install mmcls to use {cfg["type"]} dataset.')
        _cfg = deepcopy(cfg)
        # strip the 'mmcls.' prefix before delegating
        _cfg['type'] = _cfg['type'][len('mmcls.'):]
        return build_dataset_mmcls(_cfg, default_args)
    return build_from_cfg(cfg, DATASETS, default_args)
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     persistent_workers=False,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int | None): Base random seed used to seed the workers.
            Default: None.
        persistent_workers (bool, optional): If True, the data loader will
            not shutdown the worker processes after a dataset has been
            consumed once. This allows to maintain the workers Dataset
            instances alive. The argument also has effect in PyTorch>=1.7.0.
            Default: False.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        # in ddp each process draws `samples_per_gpu` samples; shuffling is
        # delegated to the distributed sampler
        sampler = DistributedSampler(
            dataset,
            world_size,
            rank,
            shuffle=shuffle,
            samples_per_gpu=samples_per_gpu,
            seed=seed)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    # seed workers deterministically only when a seed was requested
    if seed is not None:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    else:
        init_fn = None

    if (digit_version(TORCH_VERSION) >= digit_version('1.7.0')
            and TORCH_VERSION != 'parrots'):
        kwargs['persistent_workers'] = persistent_workers
    elif persistent_workers is True:
        warnings.warn('persistent_workers is invalid because your pytorch '
                      'version is lower than 1.7.0')

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        shuffle=shuffle,
        worker_init_fn=init_fn,
        **kwargs)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed a dataloader worker deterministically.

    The seed of each worker equals ``num_workers * rank + worker_id + seed``
    so every worker on every rank draws a distinct, reproducible stream.

    Args:
        worker_id (int): Id of the worker within the dataloader.
        num_workers (int): Number of workers per process/rank.
        rank (int): Rank of the current process.
        seed (int): Base random seed given by the user.
    """
    worker_seed = seed + rank * num_workers + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the
    dataset is small. Using RepeatDataset can reduce the data loading time
    between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # cache the original length once; index lookups wrap around it
        self._ori_len = len(dataset)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item; taken modulo the
                original dataset length.
        """
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset (original length times ``times``).
        """
        return self._ori_len * self.times
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class GrowScaleImgDataset(Dataset):
    """Grow Scale Unconditional Image Dataset.

    This dataset is similar to ``UnconditionalImageDataset``, but offers
    more dynamic functionalities for supporting complex algorithms, like
    PGGAN.

    Highlight functionalities:

    #. Support growing scale dataset. The motivation is to decrease data
       pre-processing load in CPU. In this dataset, you can provide
       ``imgs_roots`` like:

       .. code-block:: python

           {'64': 'path_to_64x64_imgs',
            '512': 'path_to_512x512_imgs'}

       Then, in training scales lower than 64x64, this dataset will set
       ``self.imgs_root`` as 'path_to_64x64_imgs';
    #. Offer ``samples_per_gpu`` according to different scales. In this
       dataset, ``self.samples_per_gpu`` will help runner to know the
       updated batch size.

    Basically, this dataset contains raw images for training unconditional
    GANs. Given a root dir, we will recursively find all images in this
    root. The transformation on data is defined by the pipeline.

    Args:
        imgs_roots (dict): Mapping from an image scale (as a string, e.g.
            ``'64'``) to the root path holding images of that resolution.
        pipeline (list[dict | callable]): A sequence of data transforms.
        len_per_stage (int, optional): The length of dataset for each scale.
            This arg changes the length of the dataset by concatenating or
            extracting a subset. If given a value less than 0., the original
            length will be kept. Defaults to 1e6.
        gpu_samples_per_scale (dict | None, optional): Dict contains
            ``samples_per_gpu`` for each scale. For example, ``{'32': 4}``
            will set the scale of 32 with ``samples_per_gpu=4``, despite
            other scales with ``samples_per_gpu=self.gpu_samples_base``.
        gpu_samples_base (int, optional): Set default ``samples_per_gpu``
            for each scale. Defaults to 32.
        test_mode (bool, optional): If True, the dataset will work in test
            mode. Otherwise, in train mode. Default to False.
    """

    _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG')

    def __init__(self,
                 imgs_roots,
                 pipeline,
                 len_per_stage=int(1e6),
                 gpu_samples_per_scale=None,
                 gpu_samples_base=32,
                 test_mode=False):
        super().__init__()
        assert isinstance(imgs_roots, dict)
        self.imgs_roots = imgs_roots
        self._img_scales = sorted([int(x) for x in imgs_roots.keys()])
        # start from the smallest provided scale
        self._curr_scale = self._img_scales[0]
        self._actual_curr_scale = self._curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        # len_per_stage = -1, keep the original length
        self.len_per_stage = len_per_stage
        self.curr_stage = 0
        self.gpu_samples_per_scale = gpu_samples_per_scale
        if self.gpu_samples_per_scale is not None:
            assert isinstance(self.gpu_samples_per_scale, dict)
        else:
            self.gpu_samples_per_scale = dict()
        self.gpu_samples_base = gpu_samples_base
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations for the current ``imgs_root``."""
        # recursively find all of the valid images from imgs_root
        imgs_list = mmcv.scandir(
            self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True)
        self.imgs_list = [osp.join(self.imgs_root, x) for x in imgs_list]
        if self.len_per_stage > 0:
            self.concat_imgs_list_to(self.len_per_stage)
        # the batch size may differ between scales
        self.samples_per_gpu = self.gpu_samples_per_scale.get(
            str(self._actual_curr_scale), self.gpu_samples_base)

    def update_annotations(self, curr_scale):
        """Update annotations to a new image scale.

        Args:
            curr_scale (int): Current image scale.

        Returns:
            bool: Whether the annotations were updated.

        Raises:
            RuntimeError: If ``curr_scale`` is larger than every scale
                provided in ``imgs_roots``.
        """
        if curr_scale == self._actual_curr_scale:
            return False
        # pick the smallest provided scale that covers ``curr_scale``
        for scale in self._img_scales:
            if curr_scale <= scale:
                self._curr_scale = scale
                break
            if scale == self._img_scales[-1]:
                # Bug fix: the original used ``assert RuntimeError(...)``,
                # which never fails because an exception *instance* is
                # truthy; the error must actually be raised.
                raise RuntimeError(
                    f'Cannot find a suitable scale for {curr_scale}')
        self._actual_curr_scale = curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log('Update Dataset: ' + repr(self), 'mmgen')
        return True

    def concat_imgs_list_to(self, num):
        """Clip or repeat ``self.imgs_list`` to exactly ``num`` entries.

        Args:
            num (int): The length of the concatenated image list.
        """
        if num <= len(self.imgs_list):
            self.imgs_list = self.imgs_list[:num]
            return
        concat_factor = (num // len(self.imgs_list)) + 1
        self.imgs_list = (self.imgs_list * concat_factor)[:num]

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared testing data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        if not self.test_mode:
            return self.prepare_train_data(idx)
        return self.prepare_test_data(idx)

    def __repr__(self):
        dataset_name = self.__class__
        imgs_root = self.imgs_root
        num_imgs = len(self)
        return (f'dataset_name: {dataset_name}, total {num_imgs} images in '
                f'imgs_root: {imgs_root}')
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from pathlib import Path
from mmcv import scandir
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
# File suffixes accepted when scanning paired-image folders. Both lower- and
# upper-case variants are listed explicitly — suffix matching appears to be
# case-sensitive (TODO confirm against ``mmcv.scandir``).
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
@DATASETS.register_module()
class PairedImageDataset(Dataset):
    """General paired image folder dataset for image generation.

    Every sample is a single image file that holds a pair of images
    concatenated along the width dimension (A|B). Training images are read
    from '<dataroot>/train' and test images from '<dataroot>/<testdir>'.

    Args:
        dataroot (str | :obj:`Path`): Path to the folder root of paired images.
        pipeline (List[dict | callable]): A sequence of data transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        testdir (str): Subfolder of dataroot which contain test images.
            Default: 'test'.
    """

    def __init__(self, dataroot, pipeline, test_mode=False, testdir='test'):
        super().__init__()
        subdir = testdir if test_mode else 'train'
        self.dataroot = osp.join(str(dataroot), subdir)
        self.data_infos = self.load_annotations()
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)

    def load_annotations(self):
        """Collect paired image paths under ``self.dataroot``.

        Returns:
            list[dict]: One ``dict(pair_path=...)`` per image, sorted by path.
        """
        return [
            dict(pair_path=p) for p in sorted(self.scan_folder(self.dataroot))
        ]

    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image list obtained from the given folder.

        Raises:
            TypeError: If ``path`` is neither a ``str`` nor a ``Path``.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')
        path = str(path)
        found = scandir(path, suffix=IMG_EXTENSIONS, recursive=True)
        image_paths = [osp.join(path, name) for name in found]
        assert image_paths, f'{path} has no valid image file.'
        return image_paths

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of the training batch data.

        Returns:
            dict: Returned training batch.
        """
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index for getting each testing batch.

        Returns:
            Tensor: Returned testing batch.
        """
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        if not self.test_mode:
            return self.prepare_train_data(idx)
        return self.prepare_test_data(idx)
# Copyright (c) OpenMMLab. All rights reserved.
from .augmentation import (CenterCropLongEdge, Flip, NumpyPad,
RandomCropLongEdge, RandomImgNoise, Resize)
from .compose import Compose
from .crop import Crop, FixedCrop
from .formatting import Collect, ImageToTensor, ToTensor
from .loading import LoadImageFromFile
from .normalize import Normalize
# Public API of the pipelines sub-package.
__all__ = [
    'LoadImageFromFile',
    'Compose',
    'ImageToTensor',
    'Collect',
    'ToTensor',
    'Flip',
    'Resize',
    'RandomImgNoise',
    'RandomCropLongEdge',
    'CenterCropLongEdge',
    'Normalize',
    'NumpyPad',
    'Crop',
    'FixedCrop',
]
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcls.datasets import PIPELINES as CLS_PIPELINE
from ..builder import PIPELINES
@PIPELINES.register_module()
class Flip:
    """Flip the input data with a probability.

    With probability ``flip_ratio``, every entry listed in ``keys`` is
    flipped in place along ``direction``. Shapes are preserved; only the
    element order changes. A list stored under a key is flipped
    image-by-image with the same single decision.

    Required keys are the keys in attributes "keys"; added or modified keys
    are "flip", "flip_direction" and the keys in attributes "keys".

    Args:
        keys (list[str]): The images to be flipped.
        flip_ratio (float): The probability to flip the images.
        direction (str): Flip images horizontally or vertically. Options are
            "horizontal" | "vertical". Default: "horizontal".
    """

    _directions = ['horizontal', 'vertical']

    def __init__(self, keys, flip_ratio=0.5, direction='horizontal'):
        if direction not in self._directions:
            raise ValueError(f'Direction {direction} is not supported.'
                             f'Currently support ones are {self._directions}')
        self.keys = keys
        self.flip_ratio = flip_ratio
        self.direction = direction

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        do_flip = np.random.random() < self.flip_ratio
        if do_flip:
            for key in self.keys:
                value = results[key]
                images = value if isinstance(value, list) else [value]
                for img in images:
                    # in-place flip
                    mmcv.imflip_(img, self.direction)
        results['flip'] = do_flip
        results['flip_direction'] = self.direction
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(keys={self.keys}, '
                f'flip_ratio={self.flip_ratio}, '
                f'direction={self.direction})')
@PIPELINES.register_module()
class Resize:
    """Resize data to a specific size for training or resize the images to fit
    the network input regulation for testing.

    When used for resizing images to fit network input regulation, the case is
    that a network may have several downsample and then upsample operation,
    then the input height and width should be divisible by the downsample
    factor of the network.
    For example, the network would downsample the input for 5 times with
    stride 2, then the downsample factor is 2^5 = 32 and the height
    and width should be divisible by 32.

    Required keys are the keys in attribute "keys", added or modified keys are
    "keep_ratio", "scale_factor", "interpolation" and the
    keys in attribute "keys".

    All keys in "keys" should have the same shape. "test_trans" is used to
    record the test transformation to align the input's shape.

    Args:
        keys (list[str]): The images to be resized.
        scale (float | Tuple[int]): If scale is Tuple(int), target spatial
            size (h, w). Otherwise, target spatial size is scaled by input
            size. If any of scale is -1, we will rescale short edge.
            Note that when it is used, `size_factor` and `max_size` are
            useless. Default: None
        keep_ratio (bool): If set to True, images will be resized without
            changing the aspect ratio. Otherwise, it will resize images to a
            given size. Default: False.
            Note that it is used together with `scale`.
        size_factor (int): Let the output shape be a multiple of size_factor.
            Default:None.
            Note that when it is used, `scale` should be set to None and
            `keep_ratio` should be set to False.
        max_size (int): The maximum size of the longest side of the output.
            Default:None.
            Note that it is used together with `size_factor`.
        interpolation (str): Algorithm used for interpolation:
            "nearest" | "bilinear" | "bicubic" | "area" | "lanczos".
            Default: "bilinear".
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. If backend is None, the global imread_backend
            specified by ``mmcv.use_backend()`` will be used. Default: None.
    """

    def __init__(self,
                 keys,
                 scale=None,
                 keep_ratio=False,
                 size_factor=None,
                 max_size=None,
                 interpolation='bilinear',
                 backend=None):
        assert keys, 'Keys should not be empty.'
        if size_factor:
            # NOTE(review): the assert message below is a tuple, so a failing
            # assert prints the tuple rather than a single joined string.
            assert scale is None, ('When size_factor is used, scale should ',
                                   f'be None. But received {scale}.')
            assert keep_ratio is False, ('When size_factor is used, '
                                         'keep_ratio should be False.')
        if max_size:
            assert size_factor is not None, (
                'When max_size is used, '
                f'size_factor should also be set. But received {size_factor}.')
        if isinstance(scale, float):
            if scale <= 0:
                raise ValueError(f'Invalid scale {scale}, must be positive.')
        elif mmcv.is_tuple_of(scale, int):
            max_long_edge = max(scale)
            max_short_edge = min(scale)
            if max_short_edge == -1:
                # assign np.inf to long edge for rescaling short edge later.
                scale = (np.inf, max_long_edge)
        elif scale is not None:
            raise TypeError(
                f'Scale must be None, float or tuple of int, but got '
                f'{type(scale)}.')
        self.keys = keys
        self.scale = scale
        self.size_factor = size_factor
        self.max_size = max_size
        self.keep_ratio = keep_ratio
        self.interpolation = interpolation
        self.backend = backend

    def _resize(self, img, scale):
        """Resize given image with corresponding scale.

        Args:
            img (np.array): Image to be resized.
            scale (float | Tuple[int]): Scale used in resize process.

        Returns:
            tuple: Tuple contains resized image and scale factor in resize
                process.
        """
        if self.keep_ratio:
            # rescale: aspect ratio preserved, one scalar scale factor
            img, scale_factor = mmcv.imrescale(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
        else:
            # plain resize: independent width/height scale factors
            img, w_scale, h_scale = mmcv.imresize(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
            scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
        return img, scale_factor

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.size_factor:
            # round the input shape down to a multiple of size_factor,
            # optionally capped at max_size (also rounded to a multiple)
            h, w = results[self.keys[0]].shape[:2]
            new_h = h - (h % self.size_factor)
            new_w = w - (w % self.size_factor)
            if self.max_size:
                new_h = min(self.max_size - (self.max_size % self.size_factor),
                            new_h)
                new_w = min(self.max_size - (self.max_size % self.size_factor),
                            new_w)
            scale = (new_w, new_h)
        elif isinstance(self.scale, tuple) and (np.inf in self.scale):
            # find inf in self.scale, calculate ``scale`` manually
            h, w = results[self.keys[0]].shape[:2]
            if h < w:
                scale = (int(self.scale[-1] / h * w), self.scale[-1])
            else:
                scale = (self.scale[-1], int(self.scale[-1] / w * h))
        else:
            # direct use the given ones
            scale = self.scale
        # here we assume all images in self.keys have the same input size
        for key in self.keys:
            results[key], scale_factor = self._resize(results[key], scale)
            if len(results[key].shape) == 2:
                # restore the channel axis dropped by a grayscale resize
                results[key] = np.expand_dims(results[key], axis=2)
        results['scale_factor'] = scale_factor
        results['keep_ratio'] = self.keep_ratio
        results['interpolation'] = self.interpolation
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(keys={self.keys}, scale={self.scale}, '
            f'keep_ratio={self.keep_ratio}, size_factor={self.size_factor}, '
            f'max_size={self.max_size},interpolation={self.interpolation})')
        return repr_str
@PIPELINES.register_module()
class NumpyPad:
    """Pad images with :func:`numpy.pad`.

    This augmentation forwards ``padding`` (and any extra keyword
    arguments) straight to ``numpy.pad``, so any padding scheme numpy
    supports can be used. See the numpy manual:
    https://numpy.org/doc/stable/reference/generated/numpy.pad.html

    To pad a single dimension only, give a per-axis spec such as
    ``padding = ((2, 2), (0, 0), (0, 0))``, which pads only the first
    dimension of a three-dimensional input.

    Args:
        keys (list[str]): The images to be padded.
        padding (int | tuple(int)): Please refer to the args ``pad_width`` in
            ``numpy.pad``.
    """

    def __init__(self, keys, padding, **kwargs):
        self.keys = keys
        self.padding = padding
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            results[key] = np.pad(results[key], self.padding, **self.kwargs)
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + (
            f'(keys={self.keys}, padding={self.padding}, kwargs={self.kwargs})'
        )
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomImgNoise:
    """Add random noise with specific distribution and range to the input
    image.

    The noise is sampled, min-max normalized to ``[0, 1]``, mapped to
    ``[lower_bound, upper_bound]`` and added to each image in place.

    Args:
        keys (list[str]): The images to be added random noise.
        lower_bound (float, optional): The lower bound of the noise.
            Default to ``0.``.
        upper_bound (float, optional): The upper bound of the noise.
            Default to ``1 / 128.``.
        distribution (str, optional): The probability distribution of the
            noise, either 'uniform' or 'normal'. Default to 'uniform'.
    """

    def __init__(self,
                 keys,
                 lower_bound=0,
                 upper_bound=1 / 128.,
                 distribution='uniform'):
        assert keys, 'Keys should not be empty.'
        self.keys = keys
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        if distribution not in ['uniform', 'normal']:
            raise KeyError("Only support 'uniform' distribution and "
                           "'normal' distribution, receive "
                           f'{distribution}.')
        self.distribution = distribution

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.distribution == 'uniform':
            sampler = np.random.rand
        else:  # self.distribution == 'normal'
            sampler = np.random.randn
        for key in self.keys:
            raw = sampler(*results[key].shape)
            # min-max normalize, then map onto [lower_bound, upper_bound]
            spread = raw.max() - raw.min()
            noise = (raw - raw.min()) / spread * \
                (self.upper_bound - self.lower_bound)
            results[key] += noise + self.lower_bound
        return results

    def __repr__(self):
        # NOTE: ``distribution`` is intentionally kept out of the repr to
        # match the original output format.
        return (f'{self.__class__.__name__}(keys={self.keys}, '
                f'lower_bound={self.lower_bound}, '
                f'upper_bound={self.upper_bound})')
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomCropLongEdge:
    """Random crop the given image by the long edge.

    Each image is cropped to a square whose side equals its short edge;
    the position along the long edge is chosen uniformly at random.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            height, width = img.shape[:2]
            crop_size = min(height, width)
            # NOTE(review): ``np.random.randint`` excludes its upper bound,
            # so the bottom/right-most crop position can never be sampled —
            # possible off-by-one, kept as-is to preserve behavior.
            y1 = np.random.randint(0, height - crop_size) \
                if height != crop_size else 0
            x1 = np.random.randint(0, width - crop_size) \
                if width != crop_size else 0
            y2 = y1 + crop_size - 1
            x2 = x1 + crop_size - 1
            results[key] = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class CenterCropLongEdge:
    """Center crop the given image by the long edge.

    Each image is cropped to a centered square whose side equals its
    short edge.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            height, width = img.shape[:2]
            crop_size = min(height, width)
            # The original ``int(round(h - c) / 2)`` equals floor division
            # for non-negative ints, so ``// 2`` is behaviorally identical.
            y1 = (height - crop_size) // 2
            x1 = (width - crop_size) // 2
            y2 = y1 + crop_size - 1
            x2 = x1 + crop_size - 1
            results[key] = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
from copy import deepcopy
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose:
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                cfg = deepcopy(transform)
                # configs prefixed with 'mmcls.' are built from the
                # MMClassification pipeline registry instead of ours
                if cfg['type'].startswith('mmcls.'):
                    try:
                        from mmcls.datasets import PIPELINES as MMCLSPIPELINE
                    except ImportError:
                        raise ImportError('Please install mmcls to use '
                                          f'{transform["type"]} dataset.')
                    registry = MMCLSPIPELINE
                    # strip the 'mmcls.' prefix
                    cfg['type'] = cfg['type'][len('mmcls.'):]
                else:
                    registry = PIPELINES
                self.transforms.append(build_from_cfg(cfg, registry))
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError(f'transform must be callable or a dict, '
                                f'but got {type(transform)}')

    def __call__(self, data):
        """Apply every transform in order, stopping early on ``None``.

        Args:
            data (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict | None: The processed data, or ``None`` if any transform
                returned ``None``.
        """
        for transform in self.transforms:
            data = transform(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        body = ''.join(f'\n    {t}' for t in self.transforms)
        return self.__class__.__name__ + '(' + body + '\n)'
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class Crop:
    """Crop data to specific size for training.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        random_crop (bool): If set to True, it will random crop
            image. Otherwise, it will work as center crop.
    """

    def __init__(self, keys, crop_size, random_crop=True):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')
        self.keys = keys
        self.crop_size = crop_size
        self.random_crop = random_crop

    def _crop(self, data):
        """Crop one array, or each array in a list, to ``self.crop_size``.

        Returns the cropped data and the ``[x, y, w, h]`` crop bbox(es),
        preserving the "list-ness" of the input.
        """
        is_list_input = isinstance(data, list)
        items = data if is_list_input else [data]
        cropped_items = []
        bboxes = []
        for item in items:
            height, width = item.shape[:2]
            crop_h, crop_w = self.crop_size
            # never crop beyond the actual image size
            crop_h = min(height, crop_h)
            crop_w = min(width, crop_w)
            if self.random_crop:
                x_offset = np.random.randint(0, width - crop_w + 1)
                y_offset = np.random.randint(0, height - crop_h + 1)
            else:
                # center crop
                x_offset = max(0, width - crop_w) // 2
                y_offset = max(0, height - crop_h) // 2
            bboxes.append([x_offset, y_offset, crop_w, crop_h])
            cropped_items.append(item[y_offset:y_offset + crop_h,
                                      x_offset:x_offset + crop_w, ...])
        if is_list_input:
            return cropped_items, bboxes
        return cropped_items[0], bboxes[0]

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            cropped, bbox = self._crop(results[key])
            results[key] = cropped
            results[key + '_crop_bbox'] = bbox
        results['crop_size'] = self.crop_size
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'keys={self.keys}, crop_size={self.crop_size}, '
            f'random_crop={self.random_crop}')
@PIPELINES.register_module()
class FixedCrop:
    """Crop paired data (at a specific position) to specific size for
    training.

    All images listed in ``keys`` are cropped with the same offset, so a
    paired sample stays aligned.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        crop_pos (Tuple[int]): Specific position (x, y). If set to None,
            random initialize the position to crop paired data batch.
    """

    def __init__(self, keys, crop_size, crop_pos=None):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')
        if crop_pos is not None and not mmcv.is_tuple_of(crop_pos, int):
            raise TypeError(
                'Elements of crop_pos must be int and crop_pos must be'
                f' tuple or None, but got {type(crop_pos[0])} in '
                f'{type(crop_pos)}')
        self.keys = keys
        self.crop_size = crop_size
        self.crop_pos = crop_pos

    def _crop(self, data, x_offset, y_offset, crop_w, crop_h):
        """Slice ``data`` and report the crop bbox ``[x, y, w, h]``."""
        bbox = [x_offset, y_offset, crop_w, crop_h]
        patch = data[y_offset:y_offset + crop_h,
                     x_offset:x_offset + crop_w, ...]
        return patch, bbox

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        ref_h, ref_w = results[self.keys[0]].shape[:2]
        crop_h, crop_w = self.crop_size
        crop_h = min(ref_h, crop_h)
        crop_w = min(ref_w, crop_w)
        if self.crop_pos is None:
            x_offset = np.random.randint(0, ref_w - crop_w + 1)
            y_offset = np.random.randint(0, ref_h - crop_h + 1)
        else:
            x_offset, y_offset = self.crop_pos
            # clip the crop so it stays within the image
            crop_w = min(ref_w - x_offset, crop_w)
            crop_h = min(ref_h - y_offset, crop_h)
        for key in self.keys:
            # In fixed crop for paired images, sizes should be the same
            if (results[key].shape[0] != ref_h
                    or results[key].shape[1] != ref_w):
                raise ValueError(
                    'The sizes of paired images should be the same. Expected '
                    f'({ref_h}, {ref_w}), but got ({results[key].shape[0]}, '
                    f'{results[key].shape[1]}).')
            patch, bbox = self._crop(results[key], x_offset, y_offset,
                                     crop_w, crop_h)
            results[key] = patch
            results[key + '_crop_bbox'] = bbox
        results['crop_size'] = self.crop_size
        results['crop_pos'] = self.crop_pos
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'keys={self.keys}, crop_size={self.crop_size}, '
            f'crop_pos={self.crop_pos}')
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..builder import PIPELINES
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        # already a tensor: return unchanged
        converted = data
    elif isinstance(data, np.ndarray):
        converted = torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        converted = torch.tensor(data)
    elif isinstance(data, int):
        converted = torch.LongTensor([data])
    elif isinstance(data, float):
        converted = torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
    return converted
@PIPELINES.register_module()
class ToTensor:
    """Convert the values stored under ``keys`` to :obj:`torch.Tensor` in the
    data loader pipeline.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor:
    """Convert image type to `torch.Tensor` type.

    Grayscale (H, W) images receive a trailing channel axis before the
    array is transposed from (H, W, C) to (C, H, W) and converted with
    :func:`to_tensor`.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __init__(self, keys, to_float32=True):
        self.keys = keys
        self.to_float32 = to_float32

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            # deal with gray scale img: expand a color channel
            if len(results[key].shape) == 2:
                results[key] = results[key][..., None]
            # Bug fix: the original tested ``isinstance(arr, np.float32)``,
            # which is always False for ndarrays (np.float32 is a scalar
            # type), so the cast ran on every call even for float32 input.
            # Compare the array dtype instead.
            if self.to_float32 and results[key].dtype != np.float32:
                results[key] = results[key].astype(np.float32)
            results[key] = to_tensor(results[key].transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, to_float32={self.to_float32})')
@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically
    keys is set to some subset of "img", "gt_labels".

    The "meta" item is always populated. The contents of the "meta"
    dictionary depend on ``meta_keys``.

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str] | None): Required keys to be collected to
            "meta". Default: None (no meta information is collected).
    """

    def __init__(self, keys, meta_keys=None):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: Collected data: a "meta" entry wrapped in a
                :obj:`DataContainer` plus one entry per key in ``self.keys``.
        """
        data = {}
        img_meta = {}
        # Bug fix: the documented default ``meta_keys=None`` crashed here
        # (``for key in None``); treat None the same as an empty sequence.
        for key in (self.meta_keys or ()):
            img_meta[key] = results[key]
        data['meta'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, meta_keys={self.meta_keys})')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment