Commit 1401de15 authored by dongchy920's avatar dongchy920
Browse files

stylegan2_mmcv

parents
Pipeline #1274 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .paired_image_dataset import PairedImageDataset
from .pipelines import (Collect, Compose, Flip, ImageToTensor,
LoadImageFromFile, Normalize, Resize, ToTensor)
from .quick_test_dataset import QuickTestImageDataset
from .samplers import DistributedSampler
from .singan_dataset import SinGANDataset
from .unconditional_image_dataset import UnconditionalImageDataset
from .unpaired_image_dataset import UnpairedImageDataset
# Public API of the ``mmgen.datasets`` package.
__all__ = [
    'build_dataloader', 'build_dataset', 'LoadImageFromFile',
    'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor',
    'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize',
    'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset',
    'UnpairedImageDataset', 'QuickTestImageDataset'
]
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import random
import warnings
from copy import deepcopy
from functools import partial
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
from .samplers import DistributedSampler
# Raise the soft limit on open file descriptors on POSIX systems: PyTorch
# DataLoader worker processes can exhaust the default limit (often 1024).
# See https://github.com/pytorch/pytorch/issues/973
if platform.system() != 'Windows':
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    # Request at least 4096 descriptors, but never exceed the hard limit.
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

# Global registries mapping string type names to dataset / pipeline classes.
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
def build_dataset(cfg, default_args=None):
    """Build dataset.

    Args:
        cfg (dict): Config for the dataset. ``cfg['type']`` selects the
            dataset class; a ``'mmcls.'`` prefix dispatches to
            MMClassification's builder.
        default_args (dict | None, optional): Default arguments.
            Defaults to None.

    Returns:
        Object: Dataset for sampling data batch.

    Raises:
        NotImplementedError: If ``cfg`` is a list/tuple (ConcatDataset is
            not supported).
        ImportError: If an ``mmcls.*`` dataset is requested but ``mmcls``
            is not installed.
    """
    # Imported lazily to avoid a circular import with ``dataset_wrappers``.
    from .dataset_wrappers import RepeatDataset
    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError('Currently, we do NOT support ConcatDataset')
    if cfg['type'] == 'RepeatDataset':
        # Recursively build the wrapped dataset, then repeat it.
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'].startswith('mmcls.'):
        # Add support for using datasets from `MMClassification`.
        try:
            from mmcls.datasets import build_dataset as build_dataset_mmcls
        except ImportError as err:
            # Chain the original error so the failing import stays visible.
            raise ImportError(
                f'Please install mmcls to use {cfg["type"]} dataset.'
            ) from err
        _cfg = deepcopy(cfg)
        # Strip the 'mmcls.' prefix before dispatching to mmcls' builder.
        _cfg['type'] = _cfg['type'][len('mmcls.'):]
        dataset = build_dataset_mmcls(_cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)
    return dataset
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     persistent_workers=False,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training every GPU/process owns its own dataloader; in
    non-distributed training a single dataloader serves all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int | None): Seed used to derive per-worker RNG seeds.
            Default: None.
        persistent_workers (bool, optional): If True, worker processes are
            kept alive between epochs instead of being shut down. Only
            effective with PyTorch >= 1.7.0. Default: False.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        # The sampler shards and (optionally) shuffles the dataset, so the
        # DataLoader itself must not shuffle again.
        sampler = DistributedSampler(
            dataset,
            world_size,
            rank,
            shuffle=shuffle,
            samples_per_gpu=samples_per_gpu,
            seed=seed)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    # Seed each worker deterministically only when an explicit seed is given.
    if seed is None:
        init_fn = None
    else:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

    supports_persistent = (
        digit_version(TORCH_VERSION) >= digit_version('1.7.0')
        and TORCH_VERSION != 'parrots')
    if supports_persistent:
        kwargs['persistent_workers'] = persistent_workers
    elif persistent_workers is True:
        warnings.warn('persistent_workers is invalid because your pytorch '
                      'version is lower than 1.7.0')

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        shuffle=shuffle,
        worker_init_fn=init_fn,
        **kwargs)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the RNGs of a single dataloader worker process.

    The per-worker seed is ``num_workers * rank + worker_id + seed`` so every
    worker on every rank draws from an independent random stream. Seeds the
    ``numpy``, stdlib ``random`` and ``torch`` generators.
    """
    worker_seed = seed + worker_id + rank * num_workers
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
@DATASETS.register_module()
class RepeatDataset:
    """Wrap a dataset so it appears ``times`` as long as the original.

    Useful when the dataset is small but data loading between epochs is
    expensive: repeating the dataset reduces the number of reloads.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # Cache the wrapped dataset's length for index wrap-around.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Return the item at ``idx`` modulo the original dataset length.

        Args:
            idx (int): Index for getting each item.
        """
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: The original length multiplied by the repeat factor.
        """
        return self._ori_len * self.times
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class GrowScaleImgDataset(Dataset):
    """Grow Scale Unconditional Image Dataset.

    This dataset is similar with ``UnconditionalImageDataset``, but offer
    more dynamic functionalities for the supporting complex algorithms, like
    PGGAN.

    Highlight functionalities:

    #. Support growing scale dataset. The motivation is to decrease data
       pre-processing load in CPU. In this dataset, you can provide
       ``imgs_roots`` like:

       .. code-block:: python

           {'64': 'path_to_64x64_imgs',
            '512': 'path_to_512x512_imgs'}

       Then, in training scales lower than 64x64, this dataset will set
       ``self.imgs_root`` as 'path_to_64x64_imgs';
    #. Offer ``samples_per_gpu`` according to different scales. In this
       dataset, ``self.samples_per_gpu`` will help runner to know the updated
       batch size.

    Basically, This dataset contains raw images for training unconditional
    GANs. Given a root dir, we will recursively find all images in this root.
    The transformation on data is defined by the pipeline.

    Args:
        imgs_root (str): Root path for unconditional images.
        pipeline (list[dict | callable]): A sequence of data transforms.
        len_per_stage (int, optional): The length of dataset for each scale.
            This args change the length dataset by concatenating or extracting
            subset. If given a value less than 0., the original length will be
            kept. Defaults to 1e6.
        gpu_samples_per_scale (dict | None, optional): Dict contains
            ``samples_per_gpu`` for each scale. For example, ``{'32': 4}`` will
            set the scale of 32 with ``samples_per_gpu=4``, despite other scale
            with ``samples_per_gpu=self.gpu_samples_base``.
        gpu_samples_base (int, optional): Set default ``samples_per_gpu`` for
            each scale. Defaults to 32.
        test_mode (bool, optional): If True, the dataset will work in test
            mode. Otherwise, in train mode. Default to False.
    """
    _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG')

    def __init__(self,
                 imgs_roots,
                 pipeline,
                 len_per_stage=int(1e6),
                 gpu_samples_per_scale=None,
                 gpu_samples_base=32,
                 test_mode=False):
        super().__init__()
        assert isinstance(imgs_roots, dict)
        self.imgs_roots = imgs_roots
        # Available scales sorted ascending, e.g. [64, 512].
        self._img_scales = sorted([int(x) for x in imgs_roots.keys()])
        self._curr_scale = self._img_scales[0]
        self._actual_curr_scale = self._curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        # len_per_stage = -1, keep the original length
        self.len_per_stage = len_per_stage
        self.curr_stage = 0
        self.gpu_samples_per_scale = gpu_samples_per_scale
        if self.gpu_samples_per_scale is not None:
            assert isinstance(self.gpu_samples_per_scale, dict)
        else:
            self.gpu_samples_per_scale = dict()
        self.gpu_samples_base = gpu_samples_base
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations.

        Recursively scans ``self.imgs_root`` for valid image files and
        refreshes ``self.imgs_list`` and ``self.samples_per_gpu`` for the
        current scale.
        """
        # recursively find all of the valid images from imgs_root
        imgs_list = mmcv.scandir(
            self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True)
        self.imgs_list = [osp.join(self.imgs_root, x) for x in imgs_list]
        if self.len_per_stage > 0:
            self.concat_imgs_list_to(self.len_per_stage)
        # Per-scale batch size; falls back to the base value when the
        # current scale has no explicit entry.
        self.samples_per_gpu = self.gpu_samples_per_scale.get(
            str(self._actual_curr_scale), self.gpu_samples_base)

    def update_annotations(self, curr_scale):
        """Update annotations.

        Args:
            curr_scale (int): Current image scale.

        Returns:
            bool: Whether the annotations were updated.

        Raises:
            RuntimeError: If ``curr_scale`` is larger than every available
                scale in ``imgs_roots``.
        """
        if curr_scale == self._actual_curr_scale:
            return False
        # Pick the smallest provided scale that can serve ``curr_scale``.
        for scale in self._img_scales:
            if curr_scale <= scale:
                self._curr_scale = scale
                break
        else:
            # BUGFIX: this used to be ``assert RuntimeError(...)``, which
            # asserts a truthy exception object and therefore never fails.
            raise RuntimeError(
                f'Cannot find a suitable scale for {curr_scale}')
        self._actual_curr_scale = curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log('Update Dataset: ' + repr(self), 'mmgen')
        return True

    def concat_imgs_list_to(self, num):
        """Concat image list to specified length.

        Args:
            num (int): The length of the concatenated image list.
        """
        if num <= len(self.imgs_list):
            # Shrink by taking a prefix subset.
            self.imgs_list = self.imgs_list[:num]
            return
        # Grow by tiling the list, then trim to the exact length.
        concat_factor = (num // len(self.imgs_list)) + 1
        imgs = self.imgs_list * concat_factor
        self.imgs_list = imgs[:num]

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared testing data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        if not self.test_mode:
            return self.prepare_train_data(idx)
        return self.prepare_test_data(idx)

    def __repr__(self):
        dataset_name = self.__class__
        imgs_root = self.imgs_root
        num_imgs = len(self)
        return (f'dataset_name: {dataset_name}, total {num_imgs} images in '
                f'imgs_root: {imgs_root}')
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from pathlib import Path
from mmcv import scandir
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
# Image file extensions (lower- and upper-case) accepted when scanning
# folders for paired/unpaired image datasets.
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
@DATASETS.register_module()
class PairedImageDataset(Dataset):
    """General paired image folder dataset for image generation.

    Training images are expected under '/path/to/data/train' and test images
    under '/path/to/data/test', where '/path/to/data' is given by ``dataroot``.
    Each sample is a single image holding the pair (A|B) concatenated along
    the width dimension.

    Args:
        dataroot (str | :obj:`Path`): Path to the folder root of paired images.
        pipeline (List[dict | callable]): A sequence of data transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        testdir (str): Subfolder of dataroot which contain test images.
            Default: 'test'.
    """

    def __init__(self, dataroot, pipeline, test_mode=False, testdir='test'):
        super().__init__()
        phase = testdir if test_mode else 'train'
        self.dataroot = osp.join(str(dataroot), phase)
        self.data_infos = self.load_annotations()
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)

    def load_annotations(self):
        """Collect paired image paths from the data root.

        Returns:
            list[dict]: List that contains paired image paths.
        """
        pair_paths = sorted(self.scan_folder(self.dataroot))
        return [dict(pair_path=pair_path) for pair_path in pair_paths]

    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image list obtained from the given folder.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')
        path = str(path)
        found = scandir(path, suffix=IMG_EXTENSIONS, recursive=True)
        images = [osp.join(path, name) for name in found]
        assert images, f'{path} has no valid image file.'
        return images

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of the training batch data.

        Returns:
            dict: Returned training batch.
        """
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index for getting each testing batch.

        Returns:
            Tensor: Returned testing batch.
        """
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)
# Copyright (c) OpenMMLab. All rights reserved.
from .augmentation import (CenterCropLongEdge, Flip, NumpyPad,
RandomCropLongEdge, RandomImgNoise, Resize)
from .compose import Compose
from .crop import Crop, FixedCrop
from .formatting import Collect, ImageToTensor, ToTensor
from .loading import LoadImageFromFile
from .normalize import Normalize
# Public API of the ``mmgen.datasets.pipelines`` package.
__all__ = [
    'LoadImageFromFile',
    'Compose',
    'ImageToTensor',
    'Collect',
    'ToTensor',
    'Flip',
    'Resize',
    'RandomImgNoise',
    'RandomCropLongEdge',
    'CenterCropLongEdge',
    'Normalize',
    'NumpyPad',
    'Crop',
    'FixedCrop',
]
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcls.datasets import PIPELINES as CLS_PIPELINE
from ..builder import PIPELINES
@PIPELINES.register_module()
class Flip:
    """Flip the input data with a probability.

    Reverse the order of elements in the given data with a specific direction.
    The shape of the data is preserved, but the elements are reordered.
    Required keys are the keys in attributes "keys", added or modified keys are
    "flip", "flip_direction" and the keys in attributes "keys".
    It also supports flipping a list of images with the same flip.

    Args:
        keys (list[str]): The images to be flipped.
        flip_ratio (float): The propability to flip the images.
        direction (str): Flip images horizontally or vertically. Options are
            "horizontal" | "vertical". Default: "horizontal".
    """
    _directions = ['horizontal', 'vertical']

    def __init__(self, keys, flip_ratio=0.5, direction='horizontal'):
        if direction not in self._directions:
            raise ValueError(f'Direction {direction} is not supported.'
                             f'Currently support ones are {self._directions}')
        self.keys = keys
        self.flip_ratio = flip_ratio
        self.direction = direction

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        do_flip = np.random.random() < self.flip_ratio
        if do_flip:
            for key in self.keys:
                entry = results[key]
                # Normalize to a list so single images and image lists share
                # the same in-place flip path.
                imgs = entry if isinstance(entry, list) else [entry]
                for img in imgs:
                    mmcv.imflip_(img, self.direction)
        results['flip'] = do_flip
        results['flip_direction'] = self.direction
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, flip_ratio={self.flip_ratio}, '
                     f'direction={self.direction})')
        return repr_str
@PIPELINES.register_module()
class Resize:
    """Resize data to a specific size for training or resize the images to fit
    the network input regulation for testing.

    When used for resizing images to fit network input regulation, the case is
    that a network may have several downsample and then upsample operation,
    then the input height and width should be divisible by the downsample
    factor of the network.
    For example, the network would downsample the input for 5 times with
    stride 2, then the downsample factor is 2^5 = 32 and the height
    and width should be divisible by 32.

    Required keys are the keys in attribute "keys", added or modified keys are
    "keep_ratio", "scale_factor", "interpolation" and the
    keys in attribute "keys".
    All keys in "keys" should have the same shape. "test_trans" is used to
    record the test transformation to align the input's shape.

    Args:
        keys (list[str]): The images to be resized.
        scale (float | Tuple[int]): If scale is Tuple(int), target spatial
            size (h, w). Otherwise, target spatial size is scaled by input
            size. If any of scale is -1, we will rescale short edge.
            Note that when it is used, `size_factor` and `max_size` are
            useless. Default: None
        keep_ratio (bool): If set to True, images will be resized without
            changing the aspect ratio. Otherwise, it will resize images to a
            given size. Default: False.
            Note that it is used togher with `scale`.
        size_factor (int): Let the output shape be a multiple of size_factor.
            Default:None.
            Note that when it is used, `scale` should be set to None and
            `keep_ratio` should be set to False.
        max_size (int): The maximum size of the longest side of the output.
            Default:None.
            Note that it is used togher with `size_factor`.
        interpolation (str): Algorithm used for interpolation:
            "nearest" | "bilinear" | "bicubic" | "area" | "lanczos".
            Default: "bilinear".
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. If backend is None, the global imread_backend
            specified by ``mmcv.use_backend()`` will be used. Default: None.
    """

    def __init__(self,
                 keys,
                 scale=None,
                 keep_ratio=False,
                 size_factor=None,
                 max_size=None,
                 interpolation='bilinear',
                 backend=None):
        assert keys, 'Keys should not be empty.'
        # ``size_factor`` mode is mutually exclusive with ``scale`` and
        # ``keep_ratio``; ``max_size`` only makes sense with ``size_factor``.
        if size_factor:
            assert scale is None, ('When size_factor is used, scale should ',
                                   f'be None. But received {scale}.')
            assert keep_ratio is False, ('When size_factor is used, '
                                         'keep_ratio should be False.')
        if max_size:
            assert size_factor is not None, (
                'When max_size is used, '
                f'size_factor should also be set. But received {size_factor}.')
        if isinstance(scale, float):
            if scale <= 0:
                raise ValueError(f'Invalid scale {scale}, must be positive.')
        elif mmcv.is_tuple_of(scale, int):
            max_long_edge = max(scale)
            max_short_edge = min(scale)
            if max_short_edge == -1:
                # assign np.inf to long edge for rescaling short edge later.
                scale = (np.inf, max_long_edge)
        elif scale is not None:
            raise TypeError(
                f'Scale must be None, float or tuple of int, but got '
                f'{type(scale)}.')
        self.keys = keys
        self.scale = scale
        self.size_factor = size_factor
        self.max_size = max_size
        self.keep_ratio = keep_ratio
        self.interpolation = interpolation
        self.backend = backend

    def _resize(self, img, scale):
        """Resize given image with corresponding scale.

        Args:
            img (np.array): Image to be resized.
            scale (float | Tuple[int]): Scale used in resize process.

        Returns:
            tuple: Tuple contains resized image and scale factor in resize
                process.
        """
        if self.keep_ratio:
            # Aspect-preserving resize; ``scale_factor`` is a single float.
            img, scale_factor = mmcv.imrescale(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
        else:
            # Exact-size resize; per-axis (w, h) scale factors.
            img, w_scale, h_scale = mmcv.imresize(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
            scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
        return img, scale_factor

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.size_factor:
            # Round the input size down to a multiple of ``size_factor``,
            # optionally capped by ``max_size`` (also rounded down).
            h, w = results[self.keys[0]].shape[:2]
            new_h = h - (h % self.size_factor)
            new_w = w - (w % self.size_factor)
            if self.max_size:
                new_h = min(self.max_size - (self.max_size % self.size_factor),
                            new_h)
                new_w = min(self.max_size - (self.max_size % self.size_factor),
                            new_w)
            scale = (new_w, new_h)
        elif isinstance(self.scale, tuple) and (np.inf in self.scale):
            # find inf in self.scale, calculate ``scale`` manually
            # (short-edge mode: scale[-1] is the target for the short edge).
            h, w = results[self.keys[0]].shape[:2]
            if h < w:
                scale = (int(self.scale[-1] / h * w), self.scale[-1])
            else:
                scale = (self.scale[-1], int(self.scale[-1] / w * h))
        else:
            # direct use the given ones
            scale = self.scale
        # here we assume all images in self.keys have the same input size
        for key in self.keys:
            results[key], scale_factor = self._resize(results[key], scale)
            # Keep grayscale outputs 3-dimensional (H, W, 1).
            if len(results[key].shape) == 2:
                results[key] = np.expand_dims(results[key], axis=2)
        results['scale_factor'] = scale_factor
        results['keep_ratio'] = self.keep_ratio
        results['interpolation'] = self.interpolation
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(keys={self.keys}, scale={self.scale}, '
            f'keep_ratio={self.keep_ratio}, size_factor={self.size_factor}, '
            f'max_size={self.max_size},interpolation={self.interpolation})')
        return repr_str
@PIPELINES.register_module()
class NumpyPad:
    """Numpy Padding.

    Applies ``numpy.pad`` to the selected entries so that arbitrary padding
    schemes can be expressed. See the numpy manual:
    https://numpy.org/doc/stable/reference/generated/numpy.pad.html

    To pad only a single dimension, ``padding`` must spell out every
    dimension, e.g.::

        padding = ((2, 2), (0, 0), (0, 0))

    which pads only the first dimension of a 3-D input.

    Args:
        keys (list[str]): The images to be resized.
        padding (int | tuple(int)): Please refer to the args ``pad_width`` in
            ``numpy.pad``.
    """

    def __init__(self, keys, padding, **kwargs):
        self.keys = keys
        self.padding = padding
        # Extra keyword arguments forwarded verbatim to ``np.pad``.
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            results[key] = np.pad(results[key], self.padding, **self.kwargs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += (
            f'(keys={self.keys}, padding={self.padding}, kwargs={self.kwargs})'
        )
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomImgNoise:
    """Add random noise with specific distribution and range to the input
    image.

    Args:
        keys (list[str]): The images to be added random noise.
        lower_bound (float, optional): The lower bound of the noise.
            Default to ``0.``.
        upper_bound (float, optional): The upper bound of the noise.
            Default to ``1 / 128.``.
        distribution (str, optional): The probability distribution of the
            noise. Default to 'uniform'.
    """

    def __init__(self,
                 keys,
                 lower_bound=0,
                 upper_bound=1 / 128.,
                 distribution='uniform'):
        assert keys, 'Keys should not be empty.'
        self.keys = keys
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        if distribution not in ['uniform', 'normal']:
            raise KeyError('Only support \'uniform\' distribution and '
                           '\'normal\' distribution, receive '
                           f'{distribution}.')
        self.distribution = distribution

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.distribution == 'uniform':
            dist_fn = np.random.rand
        else:
            # self.distribution == 'normal'
            dist_fn = np.random.randn
        for key in self.keys:
            shape = results[key].shape
            noise = dist_fn(*shape)
            # Min-max rescale the raw noise into [lower_bound, upper_bound].
            span = noise.max() - noise.min()
            noise = noise - noise.min()
            noise = noise / span * (self.upper_bound - self.lower_bound)
            noise = noise + self.lower_bound
            results[key] += noise
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, lower_bound={self.lower_bound}, '
                     f'upper_bound={self.upper_bound})')
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomCropLongEdge:
    """Random crop the given image by the long edge.

    Crops a square of side ``min(h, w)`` at a uniformly random position
    along the long edge.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            img_height, img_width = img.shape[:2]
            crop_size = min(img_height, img_width)
            # ``high`` is exclusive in np.random.randint, so ``+ 1`` is
            # needed to make the last valid offset reachable (the previous
            # code could never pick the bottom/right-most crop). The
            # ``diff + 1 >= 1`` bound also covers the square-image case
            # (offset 0) without a special guard.
            y1 = np.random.randint(0, img_height - crop_size + 1)
            x1 = np.random.randint(0, img_width - crop_size + 1)
            y2, x2 = y1 + crop_size - 1, x1 + crop_size - 1
            # mmcv.imcrop bboxes are inclusive on both ends.
            img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
            results[key] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys})')
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class CenterCropLongEdge:
    """Center crop the given image by the long edge.

    Crops a centered square of side ``min(h, w)``.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            img_height, img_width = img.shape[:2]
            crop_size = min(img_height, img_width)
            # BUGFIX: the previous code wrote ``int(round(h - c) / 2)``,
            # applying ``round`` to the integer difference (a no-op) and
            # then truncating, which shifted odd-difference crops one pixel
            # up/left of center. Round the half-difference instead, matching
            # the conventional center-crop offset ``round((h - c) / 2)``.
            y1 = int(round((img_height - crop_size) / 2))
            x1 = int(round((img_width - crop_size) / 2))
            y2 = y1 + crop_size - 1
            x2 = x1 + crop_size - 1
            # mmcv.imcrop bboxes are inclusive on both ends.
            img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
            results[key] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys})')
        return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
from copy import deepcopy
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose:
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, Sequence)
        self.transforms = []
        for transform in transforms:
            if callable(transform):
                self.transforms.append(transform)
                continue
            if not isinstance(transform, dict):
                raise TypeError(f'transform must be callable or a dict, '
                                f'but got {type(transform)}')
            # Work on a copy so the caller's config is left untouched.
            transform_cfg = deepcopy(transform)
            if transform_cfg['type'].startswith('mmcls.'):
                # add support for using pipelines from `MMClassification`
                try:
                    from mmcls.datasets import PIPELINES as MMCLSPIPELINE
                except ImportError:
                    raise ImportError('Please install mmcls to use '
                                      f'{transform["type"]} dataset.')
                pipeline_source = MMCLSPIPELINE
                # remove prefix
                transform_cfg['type'] = transform_cfg['type'][6:]
            else:
                pipeline_source = PIPELINES
            self.transforms.append(
                build_from_cfg(transform_cfg, pipeline_source))

    def __call__(self, data):
        """Call function.

        Args:
            data (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        # Abort the pipeline as soon as any transform drops the sample.
        for transform in self.transforms:
            data = transform(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for transform in self.transforms:
            format_string += '\n'
            format_string += f'    {transform}'
        format_string += '\n)'
        return format_string
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class Crop:
    """Crop data to specific size for training.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        random_crop (bool): If set to True, it will random crop
            image. Otherwise, it will work as center crop.
    """

    def __init__(self, keys, crop_size, random_crop=True):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')
        self.keys = keys
        self.crop_size = crop_size
        self.random_crop = random_crop

    def _crop(self, data, ):
        """Crop one entry (a single array or a list of arrays).

        Returns the cropped data together with the per-item crop bbox
        ``[x_offset, y_offset, crop_w, crop_h]``, preserving the input's
        single-item vs list structure.
        """
        if not isinstance(data, list):
            data_list = [data]
        else:
            data_list = data
        crop_bbox_list = []
        data_list_ = []
        for item in data_list:
            data_h, data_w = item.shape[:2]
            crop_h, crop_w = self.crop_size
            # Clamp the crop to the data size so small inputs still work.
            crop_h = min(data_h, crop_h)
            crop_w = min(data_w, crop_w)
            if self.random_crop:
                # ``+ 1`` because np.random.randint's high bound is exclusive.
                x_offset = np.random.randint(0, data_w - crop_w + 1)
                y_offset = np.random.randint(0, data_h - crop_h + 1)
            else:
                # Center crop (truncating division).
                x_offset = max(0, (data_w - crop_w)) // 2
                y_offset = max(0, (data_h - crop_h)) // 2
            crop_bbox = [x_offset, y_offset, crop_w, crop_h]
            item_ = item[y_offset:y_offset + crop_h,
                         x_offset:x_offset + crop_w, ...]
            crop_bbox_list.append(crop_bbox)
            data_list_.append(item_)
        if not isinstance(data, list):
            return data_list_[0], crop_bbox_list[0]
        return data_list_, crop_bbox_list

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for k in self.keys:
            data_, crop_bbox = self._crop(results[k])
            results[k] = data_
            results[k + '_crop_bbox'] = crop_bbox
        results['crop_size'] = self.crop_size
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        # Parenthesized like every other pipeline repr in this package
        # (the previous version omitted the surrounding parentheses).
        repr_str += (f'(keys={self.keys}, crop_size={self.crop_size}, '
                     f'random_crop={self.random_crop})')
        return repr_str
@PIPELINES.register_module()
class FixedCrop:
    """Crop paired data (at a specific position) to specific size for training.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        crop_pos (Tuple[int]): Specific position (x, y). If set to None,
            random initialize the position to crop paired data batch.
    """

    def __init__(self, keys, crop_size, crop_pos=None):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')
        if not mmcv.is_tuple_of(crop_pos, int) and (crop_pos is not None):
            raise TypeError(
                'Elements of crop_pos must be int and crop_pos must be'
                f' tuple or None, but got {type(crop_pos[0])} in '
                f'{type(crop_pos)}')
        self.keys = keys
        self.crop_size = crop_size
        self.crop_pos = crop_pos

    def _crop(self, data, x_offset, y_offset, crop_w, crop_h):
        """Slice out the crop window and return it with its bbox."""
        crop_bbox = [x_offset, y_offset, crop_w, crop_h]
        data_ = data[y_offset:y_offset + crop_h, x_offset:x_offset + crop_w,
                     ...]
        return data_, crop_bbox

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        # The first key fixes the reference size; all paired entries must
        # match it exactly (checked in the loop below).
        data_h, data_w = results[self.keys[0]].shape[:2]
        crop_h, crop_w = self.crop_size
        crop_h = min(data_h, crop_h)
        crop_w = min(data_w, crop_w)
        if self.crop_pos is None:
            # ``+ 1`` because np.random.randint's high bound is exclusive.
            x_offset = np.random.randint(0, data_w - crop_w + 1)
            y_offset = np.random.randint(0, data_h - crop_h + 1)
        else:
            x_offset, y_offset = self.crop_pos
            # Clamp the crop so a fixed position near the border stays valid.
            crop_w = min(data_w - x_offset, crop_w)
            crop_h = min(data_h - y_offset, crop_h)
        for k in self.keys:
            # In fixed crop for paired images, sizes should be the same
            if (results[k].shape[0] != data_h
                    or results[k].shape[1] != data_w):
                raise ValueError(
                    'The sizes of paired images should be the same. Expected '
                    f'({data_h}, {data_w}), but got ({results[k].shape[0]}, '
                    f'{results[k].shape[1]}).')
            data_, crop_bbox = self._crop(results[k], x_offset, y_offset,
                                          crop_w, crop_h)
            results[k] = data_
            results[k + '_crop_bbox'] = crop_bbox
        results['crop_size'] = self.crop_size
        results['crop_pos'] = self.crop_pos
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        # Parenthesized like every other pipeline repr in this package
        # (the previous version omitted the surrounding parentheses).
        repr_str += (f'(keys={self.keys}, crop_size={self.crop_size}, '
                     f'crop_pos={self.crop_pos})')
        return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..builder import PIPELINES
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    # Tensors pass through unchanged; ndarrays convert without copying.
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Non-string sequences become tensors; scalars are wrapped in
    # one-element tensors of the matching dtype.
    is_non_str_seq = isinstance(data, Sequence) and not mmcv.is_str(data)
    if is_non_str_seq:
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
class ToTensor:
    """Convert some values in results dict to `torch.Tensor` type in data
    loader pipeline.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        # Replace each requested entry with its tensor form, in place.
        results.update(
            {key: to_tensor(results[key])
             for key in self.keys})
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor:
    """Convert image type to `torch.Tensor` type.

    Images are transposed from HWC to CHW layout before conversion so they
    match the layout expected by torch models.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __init__(self, keys, to_float32=True):
        self.keys = keys
        self.to_float32 = to_float32

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            # deal with gray scale img: expand a color channel
            if len(img.shape) == 2:
                img = img[..., None]
            # Fix: the original tested `isinstance(img, np.float32)`, which
            # is always False for an ndarray, so the cast ran even when the
            # data was already float32. Comparing the dtype keeps the same
            # result while skipping the redundant copy.
            if self.to_float32 and img.dtype != np.float32:
                img = img.astype(np.float32)
            # HWC -> CHW before tensor conversion.
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, to_float32={self.to_float32})')
@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img", "gt_labels".

    The "img_meta" item is always populated. The contents of the "meta"
    dictionary depends on "meta_keys".

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str]): Required keys to be collected to "meta".
            Default: None.
    """

    def __init__(self, keys, meta_keys=None):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        data = {}
        img_meta = {}
        # Fix: meta_keys defaults to None (per the documented default), but
        # the original iterated it unconditionally and raised TypeError.
        # With no meta_keys an empty meta container is still populated.
        if self.meta_keys is not None:
            for key in self.meta_keys:
                img_meta[key] = results[key]
        # Wrap meta in a DataContainer kept on CPU (not collated to GPU).
        data['meta'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, meta_keys={self.meta_keys})')
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.fileio import FileClient
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile:
    """Load image from file.

    Args:
        io_backend (str): io backend where images are store. Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        backend (str | None): The image decoding backend type. Options are
            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
            global imread_backend specified by ``mmcv.use_backend()`` will be
            used. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
            ``results`` dict with name of ``f'ori_{key}'``. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 backend=None,
                 save_original_img=False,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.channel_order = channel_order
        self.backend = backend
        self.save_original_img = save_original_img
        self.kwargs = kwargs
        # Created lazily on first use so the pipeline object stays cheap to
        # build (and to pickle into dataloader workers).
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        raw_bytes = self.file_client.get(filepath)
        img = mmcv.imfrombytes(
            raw_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()
        return results

    def __repr__(self):
        """Return a short description of this transform's configuration."""
        return (f'{self.__class__.__name__}'
                f'(io_backend={self.io_backend}, key={self.key}, '
                f'flag={self.flag}, save_original_img={self.save_original_img})')
@PIPELINES.register_module()
class LoadPairedImageFromFile(LoadImageFromFile):
    """Load a pair of images from file.

    Each sample contains a pair of images, which are concatenated in the w
    dimension (a|b). This is a special loading class for generation paired
    dataset. It loads a pair of images as the common loader does and crops
    it into two images with the same shape in different domains.

    Required key is "pair_path". Added or modified keys are "pair",
    "pair_ori_shape", "ori_pair", "img_{domain_a}", "img_{domain_b}",
    "img_{domain_a}_path", "img_{domain_b}_path", "img_{domain_a}_ori_shape",
    "img_{domain_b}_ori_shape", "ori_img_{domain_a}" and
    "ori_img_{domain_b}".

    Args:
        io_backend (str): io backend where images are store. Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        domain_a (str, optional): One of the paired image domain.
            Defaults to None.
        domain_b (str, optional): The other image domain.
            Defaults to None.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='pair',
                 domain_a=None,
                 domain_b=None,
                 flag='color',
                 channel_order='bgr',
                 backend=None,
                 save_original_img=False,
                 **kwargs):
        super().__init__(
            io_backend,
            key=key,
            flag=flag,
            channel_order=channel_order,
            backend=backend,
            save_original_img=save_original_img,
            **kwargs)
        assert isinstance(domain_a, str)
        assert isinstance(domain_b, str)
        self.domain_a = domain_a
        self.domain_b = domain_b

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        # Fix: the original ignored the configured ``channel_order`` and
        # ``backend`` here (it decoded with ``flag`` only), so e.g.
        # channel_order='rgb' silently produced BGR images. Pass both
        # through as the documented contract promises.
        img = mmcv.imfrombytes(
            img_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC
        if img.ndim == 2:
            # Grayscale: add a trailing channel axis so slicing below works.
            img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        # crop pair into a and b
        w = img.shape[1]
        if w % 2 != 0:
            raise ValueError(
                f'The width of image pair must be even number, but got {w}.')
        new_w = w // 2
        img_a = img[:, :new_w, :]
        img_b = img[:, new_w:, :]

        results[f'img_{self.domain_a}'] = img_a
        results[f'img_{self.domain_b}'] = img_b
        results[f'img_{self.domain_a}_path'] = filepath
        results[f'img_{self.domain_b}_path'] = filepath
        results[f'img_{self.domain_a}_ori_shape'] = img_a.shape
        results[f'img_{self.domain_b}_ori_shape'] = img_b.shape
        if self.save_original_img:
            results[f'ori_img_{self.domain_a}'] = img_a.copy()
            results[f'ori_img_{self.domain_b}'] = img_b.copy()
        return results
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class Normalize:
    """Normalize images with the given mean and std value.

    Required keys are the keys in attribute "keys", added or modified keys are
    the keys in attribute "keys" and these keys with postfix '_norm_cfg'.
    It also supports normalizing a list of images.

    Args:
        keys (Sequence[str]): The images to be normalized.
        mean (np.ndarray): Mean values of different channels.
        std (np.ndarray): Std values of different channels.
        to_rgb (bool): Whether to convert channels from BGR to RGB.
    """

    def __init__(self, keys, mean, std, to_rgb=False):
        self.keys = keys
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            value = results[key]
            # An entry may be a single image or a list of images.
            if isinstance(value, list):
                results[key] = [
                    mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
                    for img in value
                ]
            else:
                results[key] = mmcv.imnormalize(value, self.mean, self.std,
                                                self.to_rgb)
        # Record the normalization config so it can be undone downstream.
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        """Return a short description of this transform's configuration."""
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, mean={self.mean}, std={self.std}, '
                f'to_rgb={self.to_rgb})')
@PIPELINES.register_module()
class RescaleToZeroOne:
    """Transform the images into a range between 0 and 1.

    Required keys are the keys in attribute "keys", added or modified keys are
    the keys in attribute "keys".
    It also supports rescaling a list of images.

    Args:
        keys (Sequence[str]): The images to be transformed.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """

        def _rescale(img):
            # Map the uint8 range [0, 255] onto float32 [0, 1].
            return img.astype(np.float32) / 255.

        for key in self.keys:
            value = results[key]
            if isinstance(value, list):
                results[key] = [_rescale(img) for img in value]
            else:
                results[key] = _rescale(value)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.utils.data import Dataset
from .builder import DATASETS
@DATASETS.register_module()
class QuickTestImageDataset(Dataset):
    """Dataset for quickly testing the correctness.

    Args:
        size (tuple[int]): The size of the images. Defaults to `None`.
            NOTE(review): a (H, W) tuple is required in practice — passing
            ``None`` would fail when the fixed tensor is created below.
    """

    def __init__(self, *args, size=None, **kwargs):
        super().__init__()
        self.size = size
        # One fixed random CHW image shared by every sample.
        height, width = self.size[0], self.size[1]
        self.img_tensor = torch.randn(3, height, width)

    def __len__(self):
        # Pretend to be a reasonably large dataset.
        return 10000

    def __getitem__(self, idx):
        # Every index returns the same cached image.
        return dict(real_img=self.img_tensor)
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler
__all__ = ['DistributedSampler']
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import numpy as np
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler
from mmgen.utils import sync_random_seed
class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    In pytorch of lower versions, there is no `shuffle` argument. This child
    class will port one to DistributedSampler.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 samples_per_gpu=1,
                 seed=None):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.samples_per_gpu = samples_per_gpu
        # fix the bug of the official implementation
        # Round up in units of whole per-GPU batches so every replica draws
        # the same number of full batches per epoch.
        self.num_samples_per_replica = int(
            int(
                np.ceil(
                    len(self.dataset) * 1.0 / self.num_replicas /
                    samples_per_gpu)))
        self.num_samples = self.num_samples_per_replica * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas
        # to avoid padding bug when meeting too small dataset
        if len(dataset) < self.num_replicas * samples_per_gpu:
            raise ValueError(
                'You may use too small dataset and our distributed '
                'sampler cannot pad your dataset correctly. We highly '
                'recommend you to use fewer GPUs to finish your work')

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def update_sampler(self, dataset, samples_per_gpu=None):
        """Rebind this sampler to a (possibly new) dataset and recompute the
        per-replica sample bookkeeping.

        Args:
            dataset: The new dataset to sample from.
            samples_per_gpu (int, optional): New batch size per GPU; keeps
                the current value when None.
        """
        self.dataset = dataset
        if samples_per_gpu is not None:
            self.samples_per_gpu = samples_per_gpu

        # fix the bug of the official implementation
        # Same rounding logic as in __init__: whole per-GPU batches only.
        self.num_samples_per_replica = int(
            int(
                np.ceil(
                    len(self.dataset) * 1.0 / self.num_replicas /
                    self.samples_per_gpu)))
        self.num_samples = self.num_samples_per_replica * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

        # to avoid padding bug when meeting too small dataset
        if len(dataset) < self.num_replicas * self.samples_per_gpu:
            raise ValueError(
                'You may use too small dataset and our distributed '
                'sampler cannot pad your dataset correctly. We highly '
                'recommend you to use fewer GPUs to finish your work')

    def __iter__(self):
        """Yield this replica's index sequence for the current epoch."""
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # (pads by repeating the head of the shuffled list)
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        # Interleaved split: rank r takes positions r, r+world, r+2*world, ...
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset
from .builder import DATASETS
def create_real_pyramid(real, min_size, max_size, scale_factor_init):
    """Create image pyramid.

    This function is modified from the official implementation:
    https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221

    In this implementation, we adopt the rescaling function from MMCV.

    Args:
        real (np.array): The real image array.
        min_size (int): The minimum size for the image pyramid.
        max_size (int): The maximum size for the image pyramid.
        scale_factor_init (float): The initial scale factor.

    Returns:
        tuple(list[np.array], float, int): The image pyramid ordered from
            smallest to largest, the actual per-scale factor used, and the
            index of the last (largest) scale.
    """
    shorter = min(real.shape[0], real.shape[1])
    longer = max(real.shape[0], real.shape[1])

    # Number of downscaling steps needed to shrink the shorter edge to
    # min_size with the requested factor. (The original wrapped the ratio
    # in `np.power(..., 1)`, a no-op that is dropped here.)
    num_scales = int(
        np.ceil(
            np.log(min_size / shorter) / np.log(scale_factor_init))) + 1
    # Extra steps implied by clamping the longer edge to max_size.
    scale2stop = int(
        np.ceil(
            np.log(min([max_size, longer]) / longer) /
            np.log(scale_factor_init)))
    stop_scale = num_scales - scale2stop

    # Largest pyramid image: clamp the longer edge to max_size, never
    # upscale.
    scale1 = min(max_size / longer, 1)
    real_max = mmcv.imrescale(real, scale1)

    # Actual factor so that exactly `stop_scale` downscalings reach
    # min_size from the clamped image.
    scale_factor = np.power(
        min_size / (min(real_max.shape[0], real_max.shape[1])),
        1 / (stop_scale))
    # NOTE(review): the original recomputed `scale2stop`/`stop_scale` here
    # with identical inputs; that duplicate (dead) computation is removed.

    reals = []
    for i in range(stop_scale + 1):
        scale = np.power(scale_factor, stop_scale - i)
        reals.append(mmcv.imrescale(real, scale))
    return reals, scale_factor, stop_scale
@DATASETS.register_module()
class SinGANDataset(Dataset):
    """SinGAN Dataset.

    In this dataset, we create an image pyramid and save it in the cache.

    Args:
        img_path (str): Path to the single image file.
        min_size (int): Min size of the image pyramid. Here, the number will
            be set to the ``min(H, W)``.
        max_size (int): Max size of the image pyramid. Here, the number will
            be set to the ``max(H, W)``.
        scale_factor_init (float): Rescale factor. Note that the actual
            factor we use may be a little bit different from this value.
        num_samples (int, optional): The number of samples (length) in this
            dataset. Defaults to -1.
    """

    def __init__(self,
                 img_path,
                 min_size,
                 max_size,
                 scale_factor_init,
                 num_samples=-1):
        self.img_path = img_path
        assert mmcv.is_filepath(self.img_path)
        self.load_annotations(min_size, max_size, scale_factor_init)
        self.num_samples = num_samples

    def load_annotations(self, min_size, max_size, scale_factor_init):
        """Load annotations for SinGAN Dataset.

        Args:
            min_size (int): The minimum size for the image pyramid.
            max_size (int): The maximum size for the image pyramid.
            scale_factor_init (float): The initial scale factor.
        """
        real = mmcv.imread(self.img_path)
        pyramid = create_real_pyramid(real, min_size, max_size,
                                      scale_factor_init)
        self.reals, self.scale_factor, self.stop_scale = pyramid

        # Cache every scale as a tensor so __getitem__ is a dict lookup.
        self.data_dict = {
            f'real_scale{level}': self._img2tensor(img)
            for level, img in enumerate(self.reals)
        }
        # Zero noise map matching the smallest scale.
        self.data_dict['input_sample'] = torch.zeros_like(
            self.data_dict['real_scale0'])

    def _img2tensor(self, img):
        """Convert an HWC uint8 image to a CHW float tensor in [-1, 1]."""
        tensor = torch.from_numpy(img).to(torch.float32)
        tensor = tensor.permute(2, 0, 1).contiguous()
        return (tensor / 255 - 0.5) * 2

    def __getitem__(self, index):
        # Every index returns the same cached pyramid dict.
        return self.data_dict

    def __len__(self):
        # Effectively unbounded unless num_samples is given.
        return int(1e6) if self.num_samples < 0 else self.num_samples
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class UnconditionalImageDataset(Dataset):
    """Unconditional Image Dataset.

    This dataset contains raw images for training unconditional GANs. Given
    a root dir, we will recursively find all images in this root. The
    transformation on data is defined by the pipeline.

    Args:
        imgs_root (str): Root path for unconditional images.
        pipeline (list[dict | callable]): A sequence of data transforms.
        test_mode (bool, optional): If True, the dataset will work in test
            mode. Otherwise, in train mode. Default to False.
    """

    _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG')

    def __init__(self, imgs_root, pipeline, test_mode=False):
        super().__init__()
        self.imgs_root = imgs_root
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        self.load_annotations()

        # print basic dataset information to check the validity
        mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations."""
        # recursively find all of the valid images from imgs_root
        found = mmcv.scandir(
            self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True)
        self.imgs_list = [osp.join(self.imgs_root, name) for name in found]

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        return self.pipeline(dict(real_img_path=self.imgs_list[idx]))

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        # Test-time preparation is identical to training here.
        return self.pipeline(dict(real_img_path=self.imgs_list[idx]))

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __repr__(self):
        """Summarize the dataset for logging."""
        return (f'dataset_name: {self.__class__}, total {len(self)} images '
                f'in imgs_root: {self.imgs_root}')
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from pathlib import Path
import numpy as np
from mmcv import scandir
from torch.utils.data import Dataset
from .builder import DATASETS
from .pipelines import Compose
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
'.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
@DATASETS.register_module()
class UnpairedImageDataset(Dataset):
    """General unpaired image folder dataset for image generation.

    It assumes that the training directory of images from domain A is
    '/path/to/data/trainA', and that from domain B is '/path/to/data/trainB',
    respectively. '/path/to/data' can be initialized by args 'dataroot'.
    During test time, the directory is '/path/to/data/testA' and
    '/path/to/data/testB', respectively.

    Args:
        dataroot (str | :obj:`Path`): Path to the folder root of unpaired
            images.
        pipeline (List[dict | callable]): A sequence of data transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        domain_a (str, optional): Domain of images in trainA / testA.
            Defaults to None.
        domain_b (str, optional): Domain of images in trainB / testB.
            Defaults to None.
    """

    def __init__(self,
                 dataroot,
                 pipeline,
                 test_mode=False,
                 domain_a=None,
                 domain_b=None):
        super().__init__()
        # Select the train/test sub-folders for each domain.
        phase = 'test' if test_mode else 'train'
        self.dataroot_a = osp.join(str(dataroot), phase + 'A')
        self.dataroot_b = osp.join(str(dataroot), phase + 'B')
        self.data_infos_a = self.load_annotations(self.dataroot_a)
        self.data_infos_b = self.load_annotations(self.dataroot_b)
        self.len_a = len(self.data_infos_a)
        self.len_b = len(self.data_infos_b)
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        assert isinstance(domain_a, str)
        assert isinstance(domain_b, str)
        self.domain_a = domain_a
        self.domain_b = domain_b

    def load_annotations(self, dataroot):
        """Load unpaired image paths of one domain.

        Args:
            dataroot (str): Path to the folder root for unpaired images of
                one domain.

        Returns:
            list[dict]: List that contains unpaired image paths of one
                domain.
        """
        return [dict(path=p) for p in sorted(self.scan_folder(dataroot))]

    def prepare_train_data(self, idx):
        """Prepare unpaired training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        # Domain A is indexed deterministically; domain B is drawn at
        # random so the pairing changes between epochs.
        path_a = self.data_infos_a[idx % self.len_a]['path']
        path_b = self.data_infos_b[np.random.randint(0, self.len_b)]['path']
        return self.pipeline({
            f'img_{self.domain_a}_path': path_a,
            f'img_{self.domain_b}_path': path_b,
        })

    def prepare_test_data(self, idx):
        """Prepare unpaired test data.

        Args:
            idx (int): Index of current batch.

        Returns:
            list[dict]: Prepared test data batch.
        """
        # Deterministic pairing for reproducible evaluation.
        path_a = self.data_infos_a[idx % self.len_a]['path']
        path_b = self.data_infos_b[idx % self.len_b]['path']
        return self.pipeline({
            f'img_{self.domain_a}_path': path_a,
            f'img_{self.domain_b}_path': path_b,
        })

    def __len__(self):
        return max(self.len_a, self.len_b)

    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image list obtained from the given folder.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')
        path = str(path)
        found = scandir(path, suffix=IMG_EXTENSIONS, recursive=True)
        images = [osp.join(path, name) for name in found]
        assert images, f'{path} has no valid image file.'
        return images

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)
# Copyright (c) OpenMMLab. All rights reserved.
from .architectures import * # noqa: F401, F403
from .builder import MODELS, MODULES, build_model, build_module
from .common import * # noqa: F401, F403
from .diffusions import * # noqa: F401, F403
from .gans import * # noqa: F401, F403
from .losses import * # noqa: F401, F403
from .misc import * # noqa: F401, F403
from .translation_models import * # noqa: F401, F403
__all__ = ['build_model', 'MODELS', 'build_module', 'MODULES']
# Copyright (c) OpenMMLab. All rights reserved.
from .arcface import IDLossModel
from .biggan import (BigGANDeepDiscriminator, BigGANDeepGenerator,
BigGANDiscriminator, BigGANGenerator, SNConvModule)
from .cyclegan import ResnetGenerator
from .dcgan import DCGANDiscriminator, DCGANGenerator
from .ddpm import DenoisingUnet
from .fid_inception import InceptionV3
from .lpips import PerceptualLoss
from .lsgan import LSGANDiscriminator, LSGANGenerator
from .pggan import (EqualizedLR, EqualizedLRConvDownModule,
EqualizedLRConvModule, EqualizedLRConvUpModule,
EqualizedLRLinearModule, MiniBatchStddevLayer,
PGGANDiscriminator, PGGANGenerator, PGGANNoiseTo2DFeat,
PixelNorm, equalized_lr)
from .pix2pix import PatchDiscriminator, generation_init_weights
from .positional_encoding import CatersianGrid, SinusoidalPositionalEmbedding
from .singan import SinGANMultiScaleDiscriminator, SinGANMultiScaleGenerator
from .sngan_proj import ProjDiscriminator, SNGANGenerator
from .stylegan import (MSStyleGAN2Discriminator, MSStyleGANv2Generator,
StyleGAN1Discriminator, StyleGAN2Discriminator,
StyleGANv1Generator, StyleGANv2Generator,
StyleGANv3Generator)
from .wgan_gp import WGANGPDiscriminator, WGANGPGenerator
__all__ = [
'DCGANGenerator', 'DCGANDiscriminator', 'EqualizedLR',
'EqualizedLRConvModule', 'equalized_lr', 'EqualizedLRLinearModule',
'EqualizedLRConvUpModule', 'EqualizedLRConvDownModule', 'PixelNorm',
'MiniBatchStddevLayer', 'PGGANNoiseTo2DFeat', 'PGGANGenerator',
'PGGANDiscriminator', 'InceptionV3', 'SinGANMultiScaleDiscriminator',
'SinGANMultiScaleGenerator', 'CatersianGrid',
'SinusoidalPositionalEmbedding', 'StyleGAN2Discriminator',
'StyleGANv2Generator', 'StyleGANv1Generator', 'StyleGAN1Discriminator',
'MSStyleGAN2Discriminator', 'MSStyleGANv2Generator',
'generation_init_weights', 'PatchDiscriminator', 'ResnetGenerator',
'PerceptualLoss', 'WGANGPDiscriminator', 'WGANGPGenerator',
'LSGANDiscriminator', 'LSGANGenerator', 'ProjDiscriminator',
'SNGANGenerator', 'BigGANGenerator', 'SNConvModule', 'BigGANDiscriminator',
'BigGANDeepGenerator', 'BigGANDeepDiscriminator', 'DenoisingUnet',
'StyleGANv3Generator', 'IDLossModel'
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment