Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
stylegan2_mmcv
Commits
1401de15
Commit
1401de15
authored
Jun 28, 2024
by
dongchy920
Browse files
stylegan2_mmcv
parents
Pipeline
#1274
canceled with stages
Changes
463
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2128 additions
and
0 deletions
+2128
-0
build/lib/mmgen/datasets/__init__.py
build/lib/mmgen/datasets/__init__.py
+20
-0
build/lib/mmgen/datasets/builder.py
build/lib/mmgen/datasets/builder.py
+146
-0
build/lib/mmgen/datasets/dataset_wrappers.py
build/lib/mmgen/datasets/dataset_wrappers.py
+39
-0
build/lib/mmgen/datasets/grow_scale_image_dataset.py
build/lib/mmgen/datasets/grow_scale_image_dataset.py
+179
-0
build/lib/mmgen/datasets/paired_image_dataset.py
build/lib/mmgen/datasets/paired_image_dataset.py
+118
-0
build/lib/mmgen/datasets/pipelines/__init__.py
build/lib/mmgen/datasets/pipelines/__init__.py
+25
-0
build/lib/mmgen/datasets/pipelines/augmentation.py
build/lib/mmgen/datasets/pipelines/augmentation.py
+438
-0
build/lib/mmgen/datasets/pipelines/compose.py
build/lib/mmgen/datasets/pipelines/compose.py
+68
-0
build/lib/mmgen/datasets/pipelines/crop.py
build/lib/mmgen/datasets/pipelines/crop.py
+161
-0
build/lib/mmgen/datasets/pipelines/formatting.py
build/lib/mmgen/datasets/pipelines/formatting.py
+141
-0
build/lib/mmgen/datasets/pipelines/loading.py
build/lib/mmgen/datasets/pipelines/loading.py
+177
-0
build/lib/mmgen/datasets/pipelines/normalize.py
build/lib/mmgen/datasets/pipelines/normalize.py
+94
-0
build/lib/mmgen/datasets/quick_test_dataset.py
build/lib/mmgen/datasets/quick_test_dataset.py
+25
-0
build/lib/mmgen/datasets/samplers/__init__.py
build/lib/mmgen/datasets/samplers/__init__.py
+4
-0
build/lib/mmgen/datasets/samplers/distributed_sampler.py
build/lib/mmgen/datasets/samplers/distributed_sampler.py
+94
-0
build/lib/mmgen/datasets/singan_dataset.py
build/lib/mmgen/datasets/singan_dataset.py
+121
-0
build/lib/mmgen/datasets/unconditional_image_dataset.py
build/lib/mmgen/datasets/unconditional_image_dataset.py
+83
-0
build/lib/mmgen/datasets/unpaired_image_dataset.py
build/lib/mmgen/datasets/unpaired_image_dataset.py
+142
-0
build/lib/mmgen/models/__init__.py
build/lib/mmgen/models/__init__.py
+11
-0
build/lib/mmgen/models/architectures/__init__.py
build/lib/mmgen/models/architectures/__init__.py
+42
-0
No files found.
Too many changes to show.
To preserve performance only
463 of 463+
files are displayed.
Plain diff
Email patch
build/lib/mmgen/datasets/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
# Public API of the ``mmgen.datasets`` package: dataset builders, dataset
# wrappers/implementations, samplers and the core pipeline transforms.
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .paired_image_dataset import PairedImageDataset
from .pipelines import (Collect, Compose, Flip, ImageToTensor,
                        LoadImageFromFile, Normalize, Resize, ToTensor)
from .quick_test_dataset import QuickTestImageDataset
from .samplers import DistributedSampler
from .singan_dataset import SinGANDataset
from .unconditional_image_dataset import UnconditionalImageDataset
from .unpaired_image_dataset import UnpairedImageDataset

__all__ = [
    'build_dataloader', 'build_dataset', 'LoadImageFromFile',
    'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor',
    'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset',
    'Normalize', 'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset',
    'UnpairedImageDataset', 'QuickTestImageDataset'
]
build/lib/mmgen/datasets/builder.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import random
import warnings
from copy import deepcopy
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader

from .samplers import DistributedSampler

if platform.system() != 'Windows':
    # Raise the soft limit on open file descriptors: many dataloader workers
    # can exhaust the default. See
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    base_soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

# Registries used across mmgen to build datasets and pipeline transforms.
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
def build_dataset(cfg, default_args=None):
    """Build dataset.

    Args:
        cfg (dict): Config for the dataset.
        default_args (dict | None, optional): Default arguments.
            Defaults to None.

    Returns:
        Object: Dataset for sampling data batch.
    """
    # local import to avoid a circular dependency with dataset_wrappers
    from .dataset_wrappers import RepeatDataset

    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError('Currently, we do NOT support ConcatDataset')
        # dataset = ConcatDataset(
        #     [build_dataset(c, default_args) for c in cfg])

    ds_type = cfg['type']
    if ds_type == 'RepeatDataset':
        # recursively build the wrapped dataset, then repeat it
        return RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])

    # add support for using datasets from `MMClassification`
    if ds_type.startswith('mmcls.'):
        try:
            from mmcls.datasets import build_dataset as build_dataset_mmcls
        except ImportError:
            raise ImportError(f'Please install mmcls to use {cfg["type"]} '
                              'dataset.')
        _cfg = deepcopy(cfg)
        _cfg['type'] = _cfg['type'][6:]  # strip the 'mmcls.' prefix
        return build_dataset_mmcls(_cfg, default_args)

    return build_from_cfg(cfg, DATASETS, default_args)
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     persistent_workers=False,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int | None): Base random seed for workers. Default: None.
        persistent_workers (bool, optional): If True, the data loader will
            not shutdown the worker processes after a dataset has been
            consumed once. Only effective for PyTorch>=1.7.0.
            Default: False.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()

    if dist:
        # the sampler shards and shuffles; the loader itself must not shuffle
        sampler = DistributedSampler(
            dataset,
            world_size,
            rank,
            shuffle=shuffle,
            samples_per_gpu=samples_per_gpu,
            seed=seed)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    if seed is not None:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    else:
        init_fn = None

    if (digit_version(TORCH_VERSION) >= digit_version('1.7.0')
            and TORCH_VERSION != 'parrots'):
        kwargs['persistent_workers'] = persistent_workers
    elif persistent_workers is True:
        warnings.warn('persistent_workers is invalid because your pytorch '
                      'version is lower than 1.7.0')

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        shuffle=shuffle,
        worker_init_fn=init_fn,
        **kwargs)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed one dataloader worker deterministically.

    The per-worker seed equals ``num_workers * rank + worker_id + seed`` so
    that every (rank, worker) pair draws from an independent random stream.
    Seeds numpy, the stdlib ``random`` module and torch.

    Args:
        worker_id (int): Id of this worker process within its rank.
        num_workers (int): Number of workers spawned per rank.
        rank (int): Rank of the current process.
        seed (int): User-provided base seed.
    """
    worker_seed = seed + worker_id + rank * num_workers
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
build/lib/mmgen/datasets/dataset_wrappers.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.builder
import
DATASETS
@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the
    dataset is small. Using RepeatDataset can reduce the data loading time
    between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # cache the wrapped length once; indices wrap around it in __getitem__
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset (original length times ``times``).
        """
        return self.times * self._ori_len
build/lib/mmgen/datasets/grow_scale_image_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
mmcv
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
from
.pipelines
import
Compose
@DATASETS.register_module()
class GrowScaleImgDataset(Dataset):
    """Grow Scale Unconditional Image Dataset.

    This dataset is similar with ``UnconditionalImageDataset``, but offers
    more dynamic functionalities for supporting complex algorithms, like
    PGGAN.

    Highlight functionalities:

    #. Support growing scale dataset. The motivation is to decrease data
       pre-processing load in CPU. In this dataset, you can provide
       ``imgs_roots`` like ``{'64': 'path_to_64x64_imgs',
       '512': 'path_to_512x512_imgs'}``. Then, in training scales lower than
       64x64, this dataset will set ``self.imgs_root`` as
       'path_to_64x64_imgs'.
    #. Offer ``samples_per_gpu`` according to different scales. In this
       dataset, ``self.samples_per_gpu`` will help runner to know the updated
       batch size.

    Basically, this dataset contains raw images for training unconditional
    GANs. Given a root dir, we will recursively find all images in this root.
    The transformation on data is defined by the pipeline.

    Args:
        imgs_roots (dict): Maps an image scale (as a str key) to the root
            path of images at that scale.
        pipeline (list[dict | callable]): A sequence of data transforms.
        len_per_stage (int, optional): The length of dataset for each scale.
            This arg changes the dataset length by concatenating or
            extracting a subset. If given a value less than 0, the original
            length will be kept. Defaults to 1e6.
        gpu_samples_per_scale (dict | None, optional): Dict containing
            ``samples_per_gpu`` for each scale. For example, ``{'32': 4}``
            will set the scale of 32 with ``samples_per_gpu=4``, despite
            other scales with ``samples_per_gpu=self.gpu_samples_base``.
        gpu_samples_base (int, optional): Default ``samples_per_gpu`` for
            each scale. Defaults to 32.
        test_mode (bool, optional): If True, the dataset will work in test
            mode. Otherwise, in train mode. Default to False.
    """

    _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG')

    def __init__(self,
                 imgs_roots,
                 pipeline,
                 len_per_stage=int(1e6),
                 gpu_samples_per_scale=None,
                 gpu_samples_base=32,
                 test_mode=False):
        super().__init__()
        assert isinstance(imgs_roots, dict)
        self.imgs_roots = imgs_roots
        self._img_scales = sorted([int(x) for x in imgs_roots.keys()])
        # start at the smallest available scale
        self._curr_scale = self._img_scales[0]
        self._actual_curr_scale = self._curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        # len_per_stage = -1, keep the original length
        self.len_per_stage = len_per_stage
        self.curr_stage = 0
        self.gpu_samples_per_scale = gpu_samples_per_scale
        if self.gpu_samples_per_scale is not None:
            assert isinstance(self.gpu_samples_per_scale, dict)
        else:
            self.gpu_samples_per_scale = dict()
        self.gpu_samples_base = gpu_samples_base
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations for the current ``imgs_root``."""
        # recursively find all of the valid images from imgs_root
        imgs_list = mmcv.scandir(
            self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True)
        self.imgs_list = [osp.join(self.imgs_root, x) for x in imgs_list]
        if self.len_per_stage > 0:
            self.concat_imgs_list_to(self.len_per_stage)
        # per-scale batch size, falling back to the base value
        self.samples_per_gpu = self.gpu_samples_per_scale.get(
            str(self._actual_curr_scale), self.gpu_samples_base)

    def update_annotations(self, curr_scale):
        """Update annotations to a new training scale.

        Args:
            curr_scale (int): Current image scale.

        Returns:
            bool: Whether the annotations were updated.

        Raises:
            RuntimeError: If ``curr_scale`` exceeds every configured scale.
        """
        if curr_scale == self._actual_curr_scale:
            return False
        for scale in self._img_scales:
            if curr_scale <= scale:
                self._curr_scale = scale
                break
            if scale == self._img_scales[-1]:
                # BUGFIX: this was `assert RuntimeError(...)`, which always
                # passes (an exception instance is truthy), so the error
                # could never be reported. Raise it instead.
                raise RuntimeError(
                    f'Cannot find a suitable scale for {curr_scale}')
        self._actual_curr_scale = curr_scale
        self.imgs_root = self.imgs_roots[str(self._curr_scale)]
        self.load_annotations()
        # print basic dataset information to check the validity
        mmcv.print_log('Update Dataset: ' + repr(self), 'mmgen')
        return True

    def concat_imgs_list_to(self, num):
        """Concat image list to specified length.

        Args:
            num (int): The length of the concatenated image list.
        """
        if num <= len(self.imgs_list):
            self.imgs_list = self.imgs_list[:num]
            return
        # repeat the list enough times, then truncate to exactly ``num``
        concat_factor = (num // len(self.imgs_list)) + 1
        imgs = self.imgs_list * concat_factor
        self.imgs_list = imgs[:num]

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared testing data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        if not self.test_mode:
            return self.prepare_train_data(idx)
        return self.prepare_test_data(idx)

    def __repr__(self):
        dataset_name = self.__class__
        imgs_root = self.imgs_root
        num_imgs = len(self)
        return (f'dataset_name: {dataset_name}, total {num_imgs} images in '
                f'imgs_root: {imgs_root}')
build/lib/mmgen/datasets/paired_image_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
os.path
as
osp
from
pathlib
import
Path
from
mmcv
import
scandir
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
from
.pipelines
import
Compose
# File suffixes accepted as images when scanning a folder.
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')


@DATASETS.register_module()
class PairedImageDataset(Dataset):
    """General paired image folder dataset for image generation.

    It assumes that the training directory is '/path/to/data/train'.
    During test time, the directory is '/path/to/data/test'. '/path/to/data'
    can be initialized by args 'dataroot'. Each sample contains a pair of
    images concatenated in the w dimension (A|B).

    Args:
        dataroot (str | :obj:`Path`): Path to the folder root of paired images.
        pipeline (List[dict | callable]): A sequence of data transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        testdir (str): Subfolder of dataroot which contain test images.
            Default: 'test'.
    """

    def __init__(self, dataroot, pipeline, test_mode=False, testdir='test'):
        super().__init__()
        phase = testdir if test_mode else 'train'
        self.dataroot = osp.join(str(dataroot), phase)
        self.data_infos = self.load_annotations()
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)

    def load_annotations(self):
        """Load paired image paths.

        Returns:
            list[dict]: List that contains paired image paths.
        """
        pair_paths = sorted(self.scan_folder(self.dataroot))
        return [dict(pair_path=p) for p in pair_paths]

    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image list obtained from the given folder.

        Raises:
            TypeError: If ``path`` is neither a str nor a Path.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')
        path = str(path)
        found = scandir(path, suffix=IMG_EXTENSIONS, recursive=True)
        images = [osp.join(path, name) for name in found]
        assert images, f'{path} has no valid image file.'
        return images

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of the training batch data.

        Returns:
            dict: Returned training batch.
        """
        # deep-copy so pipeline transforms cannot mutate the cached info
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index for getting each testing batch.

        Returns:
            Tensor: Returned testing batch.
        """
        return self.pipeline(copy.deepcopy(self.data_infos[idx]))

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)
build/lib/mmgen/datasets/pipelines/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
# Public API of ``mmgen.datasets.pipelines``: augmentation, composition,
# cropping, formatting, loading and normalization transforms.
from .augmentation import (CenterCropLongEdge, Flip, NumpyPad,
                           RandomCropLongEdge, RandomImgNoise, Resize)
from .compose import Compose
from .crop import Crop, FixedCrop
from .formatting import Collect, ImageToTensor, ToTensor
from .loading import LoadImageFromFile
from .normalize import Normalize

__all__ = [
    'LoadImageFromFile', 'Compose', 'ImageToTensor', 'Collect', 'ToTensor',
    'Flip', 'Resize', 'RandomImgNoise', 'RandomCropLongEdge',
    'CenterCropLongEdge', 'Normalize', 'NumpyPad', 'Crop', 'FixedCrop',
]
build/lib/mmgen/datasets/pipelines/augmentation.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
from
mmcls.datasets
import
PIPELINES
as
CLS_PIPELINE
from
..builder
import
PIPELINES
@PIPELINES.register_module()
class Flip:
    """Flip the input data with a probability.

    Reverses the order of elements along the chosen direction while
    preserving the array shape. Required keys are the keys in attribute
    ``keys``; added or modified keys are "flip", "flip_direction" and the
    keys in ``keys``. A list of images stored under one key is flipped with
    the same random decision.

    Args:
        keys (list[str]): The images to be flipped.
        flip_ratio (float): The probability to flip the images.
        direction (str): Flip images horizontally or vertically. Options are
            "horizontal" | "vertical". Default: "horizontal".
    """

    _directions = ['horizontal', 'vertical']

    def __init__(self, keys, flip_ratio=0.5, direction='horizontal'):
        if direction not in self._directions:
            raise ValueError(f'Direction {direction} is not supported.'
                             f'Currently support ones are {self._directions}')
        self.keys = keys
        self.flip_ratio = flip_ratio
        self.direction = direction

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        do_flip = np.random.random() < self.flip_ratio
        if do_flip:
            for key in self.keys:
                entry = results[key]
                # a key may hold a single image or a list of images;
                # imflip_ flips in place
                imgs = entry if isinstance(entry, list) else [entry]
                for img in imgs:
                    mmcv.imflip_(img, self.direction)
        results['flip'] = do_flip
        results['flip_direction'] = self.direction
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, flip_ratio={self.flip_ratio}, '
                     f'direction={self.direction})')
        return repr_str
@PIPELINES.register_module()
class Resize:
    """Resize data to a specific size for training or resize the images to
    fit the network input regulation for testing.

    When used for resizing images to fit network input regulation, the case
    is that a network may have several downsample and then upsample
    operations, then the input height and width should be divisible by the
    downsample factor of the network. For example, a network that
    downsamples the input 5 times with stride 2 has downsample factor
    2^5 = 32, so height and width should be divisible by 32.

    Required keys are the keys in attribute "keys", added or modified keys
    are "keep_ratio", "scale_factor", "interpolation" and the keys in
    attribute "keys". All keys in "keys" should have the same shape.

    Args:
        keys (list[str]): The images to be resized.
        scale (float | Tuple[int]): If scale is Tuple(int), target spatial
            size (h, w). Otherwise, target spatial size is scaled by input
            size. If any of scale is -1, we will rescale short edge.
            Note that when it is used, `size_factor` and `max_size` are
            useless. Default: None
        keep_ratio (bool): If set to True, images will be resized without
            changing the aspect ratio. Otherwise, it will resize images to a
            given size. Default: False. Used together with `scale`.
        size_factor (int): Let the output shape be a multiple of
            size_factor. Default: None. When used, `scale` should be None
            and `keep_ratio` should be False.
        max_size (int): The maximum size of the longest side of the output.
            Default: None. Used together with `size_factor`.
        interpolation (str): Algorithm used for interpolation:
            "nearest" | "bilinear" | "bicubic" | "area" | "lanczos".
            Default: "bilinear".
        backend (str | None): The image resize backend type. Options are
            `cv2`, `pillow`, `None`. If backend is None, the global
            imread_backend specified by ``mmcv.use_backend()`` will be used.
            Default: None.
    """

    def __init__(self,
                 keys,
                 scale=None,
                 keep_ratio=False,
                 size_factor=None,
                 max_size=None,
                 interpolation='bilinear',
                 backend=None):
        assert keys, 'Keys should not be empty.'
        if size_factor:
            # FIX: the original assert message was accidentally a tuple
            # ('When size_factor...', 'be None...') due to a stray comma,
            # which printed as a tuple instead of one sentence.
            assert scale is None, ('When size_factor is used, scale should '
                                   f'be None. But received {scale}.')
            assert keep_ratio is False, ('When size_factor is used, '
                                         'keep_ratio should be False.')
        if max_size:
            assert size_factor is not None, (
                'When max_size is used, '
                f'size_factor should also be set. But received {size_factor}.')
        if isinstance(scale, float):
            if scale <= 0:
                raise ValueError(f'Invalid scale {scale}, must be positive.')
        elif mmcv.is_tuple_of(scale, int):
            max_long_edge = max(scale)
            max_short_edge = min(scale)
            if max_short_edge == -1:
                # assign np.inf to long edge for rescaling short edge later.
                scale = (np.inf, max_long_edge)
        elif scale is not None:
            raise TypeError(
                f'Scale must be None, float or tuple of int, but got '
                f'{type(scale)}.')
        self.keys = keys
        self.scale = scale
        self.size_factor = size_factor
        self.max_size = max_size
        self.keep_ratio = keep_ratio
        self.interpolation = interpolation
        self.backend = backend

    def _resize(self, img, scale):
        """Resize a single image with the corresponding scale.

        Args:
            img (np.array): Image to be resized.
            scale (float | Tuple[int]): Scale used in resize process.

        Returns:
            tuple: Resized image and the scale factor used.
        """
        if self.keep_ratio:
            img, scale_factor = mmcv.imrescale(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
        else:
            img, w_scale, h_scale = mmcv.imresize(
                img,
                scale,
                return_scale=True,
                interpolation=self.interpolation,
                backend=self.backend)
            scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
        return img, scale_factor

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.size_factor:
            # round the input shape down to a multiple of size_factor,
            # optionally capped by max_size (also rounded down)
            h, w = results[self.keys[0]].shape[:2]
            new_h = h - (h % self.size_factor)
            new_w = w - (w % self.size_factor)
            if self.max_size:
                new_h = min(self.max_size - (self.max_size % self.size_factor),
                            new_h)
                new_w = min(self.max_size - (self.max_size % self.size_factor),
                            new_w)
            scale = (new_w, new_h)
        elif isinstance(self.scale, tuple) and (np.inf in self.scale):
            # find inf in self.scale, calculate ``scale`` manually
            h, w = results[self.keys[0]].shape[:2]
            if h < w:
                scale = (int(self.scale[-1] / h * w), self.scale[-1])
            else:
                scale = (self.scale[-1], int(self.scale[-1] / w * h))
        else:
            # direct use the given ones
            scale = self.scale

        # here we assume all images in self.keys have the same input size
        for key in self.keys:
            results[key], scale_factor = self._resize(results[key], scale)
            if len(results[key].shape) == 2:
                # keep a channel axis for grayscale outputs
                results[key] = np.expand_dims(results[key], axis=2)

        results['scale_factor'] = scale_factor
        results['keep_ratio'] = self.keep_ratio
        results['interpolation'] = self.interpolation
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, scale={self.scale}, '
                     f'keep_ratio={self.keep_ratio}, '
                     f'size_factor={self.size_factor}, '
                     f'max_size={self.max_size},'
                     f'interpolation={self.interpolation})')
        return repr_str
@PIPELINES.register_module()
class NumpyPad:
    """Numpy Padding.

    In this augmentation, numpy padding is adopted to customize padding
    augmentation. Please carefully read the numpy manual in:
    https://numpy.org/doc/stable/reference/generated/numpy.pad.html

    If you just hope a single dimension to be padded, you must set
    ``padding`` like ``((2, 2), (0, 0), (0, 0))``: with a three-dimensional
    input, only the first dimension will be padded.

    Args:
        keys (list[str]): The images to be padded.
        padding (int | tuple(int)): Please refer to the args ``pad_width``
            in ``numpy.pad``.
        **kwargs: Extra keyword arguments forwarded to ``numpy.pad``.
    """

    def __init__(self, keys, padding, **kwargs):
        self.keys = keys
        self.padding = padding
        # forwarded verbatim to np.pad (e.g. mode, constant_values)
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            results[key] = np.pad(results[key], self.padding, **self.kwargs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, padding={self.padding}, '
                     f'kwargs={self.kwargs})')
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomImgNoise:
    """Add random noise with specific distribution and range to the input
    image.

    Args:
        keys (list[str]): The images to be added random noise.
        lower_bound (float, optional): The lower bound of the noise.
            Default to ``0.``.
        upper_bound (float, optional): The upper bound of the noise.
            Default to ``1 / 128.``.
        distribution (str, optional): The probability distribution of the
            noise. Default to 'uniform'.
    """

    def __init__(self,
                 keys,
                 lower_bound=0,
                 upper_bound=1 / 128.,
                 distribution='uniform'):
        assert keys, 'Keys should not be empty.'
        self.keys = keys
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        if distribution not in ['uniform', 'normal']:
            raise KeyError('Only support \'uniform\' distribution and '
                           '\'normal\' distribution, receive '
                           f'{distribution}.')
        self.distribution = distribution

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        dist_fn = (np.random.rand
                   if self.distribution == 'uniform' else np.random.randn)
        for key in self.keys:
            raw = dist_fn(*results[key].shape)
            # min-max normalize the raw samples, then rescale them into
            # [lower_bound, upper_bound] before adding to the image
            span = raw.max() - raw.min()
            noise = raw - raw.min()
            noise = noise / span * (self.upper_bound - self.lower_bound)
            noise = noise + self.lower_bound
            results[key] += noise
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, lower_bound={self.lower_bound}, '
                     f'upper_bound={self.upper_bound})')
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class RandomCropLongEdge:
    """Random crop the given image by the long edge.

    Crops a square of side ``min(h, w)`` at a random offset along the longer
    dimension.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            height, width = img.shape[:2]
            crop_size = min(height, width)
            if height == crop_size:
                y1 = 0
            else:
                y1 = np.random.randint(0, height - crop_size)
            if width == crop_size:
                x1 = 0
            else:
                x1 = np.random.randint(0, width - crop_size)
            # imcrop uses inclusive bbox coordinates
            y2, x2 = y1 + crop_size - 1, x1 + crop_size - 1
            results[key] = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys})')
        return repr_str
@CLS_PIPELINE.register_module()
@PIPELINES.register_module()
class CenterCropLongEdge:
    """Center crop the given image by the long edge.

    Crops a square of side ``min(h, w)`` centered along the longer
    dimension.

    Args:
        keys (list[str]): The images to be cropped.
    """

    def __init__(self, keys):
        assert keys, 'Keys should not be empty.'
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            height, width = img.shape[:2]
            crop_size = min(height, width)
            # center the square crop on the longer dimension
            y1 = 0 if height == crop_size else (height - crop_size) // 2
            x1 = 0 if width == crop_size else (width - crop_size) // 2
            # imcrop uses inclusive bbox coordinates
            y2 = y1 + crop_size - 1
            x2 = x1 + crop_size - 1
            results[key] = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys})')
        return repr_str
build/lib/mmgen/datasets/pipelines/compose.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
collections.abc
import
Sequence
from
copy
import
deepcopy
from
mmcv.utils
import
build_from_cfg
from
..builder
import
PIPELINES
@PIPELINES.register_module()
class Compose:
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform_cfg = deepcopy(transform)
                # add support for using pipelines from `MMClassification`
                if transform_cfg['type'].startswith('mmcls.'):
                    try:
                        from mmcls.datasets import PIPELINES as MMCLSPIPELINE
                    except ImportError:
                        raise ImportError('Please install mmcls to use '
                                          f'{transform["type"]} dataset.')
                    pipeline_source = MMCLSPIPELINE
                    # strip the 'mmcls.' prefix before building
                    transform_cfg['type'] = transform_cfg['type'][6:]
                else:
                    pipeline_source = PIPELINES
                self.transforms.append(
                    build_from_cfg(transform_cfg, pipeline_source))
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError(f'transform must be callable or a dict, '
                                f'but got {type(transform)}')

    def __call__(self, data):
        """Sequentially apply every transform to ``data``.

        Args:
            data (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict | None: The processed dict, or ``None`` as soon as any
            transform returns ``None``.
        """
        out = data
        for transform in self.transforms:
            out = transform(out)
            if out is None:
                return None
        return out

    def __repr__(self):
        body = ''.join(f'\n    {t}' for t in self.transforms)
        return f'{self.__class__.__name__}({body}\n)'
build/lib/mmgen/datasets/pipelines/crop.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
from
..builder
import
PIPELINES
@PIPELINES.register_module()
class Crop:
    """Crop data to specific size for training.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        random_crop (bool): If set to True, it will random crop
            image. Otherwise, it will work as center crop.
    """

    def __init__(self, keys, crop_size, random_crop=True):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')

        self.keys = keys
        self.crop_size = crop_size
        self.random_crop = random_crop

    def _crop(self, data):
        """Crop one image or a list of images.

        Args:
            data (np.ndarray | list[np.ndarray]): Image(s) to crop.

        Returns:
            tuple: Cropped image(s) and crop bbox(es) in
                ``[x_offset, y_offset, crop_w, crop_h]`` format; scalars if
                the input was a single array, lists otherwise.
        """
        if not isinstance(data, list):
            data_list = [data]
        else:
            data_list = data

        crop_bbox_list = []
        data_list_ = []

        for item in data_list:
            data_h, data_w = item.shape[:2]
            crop_h, crop_w = self.crop_size
            # never crop beyond the image boundary
            crop_h = min(data_h, crop_h)
            crop_w = min(data_w, crop_w)

            if self.random_crop:
                # +1 because randint's upper bound is exclusive
                x_offset = np.random.randint(0, data_w - crop_w + 1)
                y_offset = np.random.randint(0, data_h - crop_h + 1)
            else:
                # center crop
                x_offset = max(0, (data_w - crop_w)) // 2
                y_offset = max(0, (data_h - crop_h)) // 2

            crop_bbox = [x_offset, y_offset, crop_w, crop_h]
            item_ = item[y_offset:y_offset + crop_h,
                         x_offset:x_offset + crop_w, ...]

            crop_bbox_list.append(crop_bbox)
            data_list_.append(item_)

        if not isinstance(data, list):
            return data_list_[0], crop_bbox_list[0]
        return data_list_, crop_bbox_list

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for k in self.keys:
            data_, crop_bbox = self._crop(results[k])
            results[k] = data_
            results[k + '_crop_bbox'] = crop_bbox
        results['crop_size'] = self.crop_size
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        # Parenthesize the arguments for consistency with the other pipeline
        # transforms in this codebase (e.g. Normalize, LoadImageFromFile);
        # the original emitted e.g. 'Cropkeys=[...], ...' with no parens.
        repr_str += (f'(keys={self.keys}, crop_size={self.crop_size}, '
                     f'random_crop={self.random_crop})')
        return repr_str
@PIPELINES.register_module()
class FixedCrop:
    """Crop paired data (at a specific position) to specific size for training.

    Args:
        keys (Sequence[str]): The images to be cropped.
        crop_size (Tuple[int]): Target spatial size (h, w).
        crop_pos (Tuple[int]): Specific position (x, y). If set to None,
            random initialize the position to crop paired data batch.
    """

    def __init__(self, keys, crop_size, crop_pos=None):
        if not mmcv.is_tuple_of(crop_size, int):
            raise TypeError(
                'Elements of crop_size must be int and crop_size must be'
                f' tuple, but got {type(crop_size[0])} in {type(crop_size)}')
        if not mmcv.is_tuple_of(crop_pos, int) and (crop_pos is not None):
            raise TypeError(
                'Elements of crop_pos must be int and crop_pos must be'
                f' tuple or None, but got {type(crop_pos[0])} in '
                f'{type(crop_pos)}')

        self.keys = keys
        self.crop_size = crop_size
        self.crop_pos = crop_pos

    def _crop(self, data, x_offset, y_offset, crop_w, crop_h):
        """Crop one array at the given offsets and return it with its bbox."""
        crop_bbox = [x_offset, y_offset, crop_w, crop_h]
        data_ = data[y_offset:y_offset + crop_h,
                     x_offset:x_offset + crop_w, ...]
        return data_, crop_bbox

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the paired images do not share the same size.
        """
        # All keys are cropped at the same position, measured on the first.
        data_h, data_w = results[self.keys[0]].shape[:2]
        crop_h, crop_w = self.crop_size
        crop_h = min(data_h, crop_h)
        crop_w = min(data_w, crop_w)

        if self.crop_pos is None:
            x_offset = np.random.randint(0, data_w - crop_w + 1)
            y_offset = np.random.randint(0, data_h - crop_h + 1)
        else:
            x_offset, y_offset = self.crop_pos
            # clamp so a fixed position near the border stays in bounds
            crop_w = min(data_w - x_offset, crop_w)
            crop_h = min(data_h - y_offset, crop_h)

        for k in self.keys:
            # In fixed crop for paired images, sizes should be the same
            if (results[k].shape[0] != data_h
                    or results[k].shape[1] != data_w):
                raise ValueError(
                    'The sizes of paired images should be the same. Expected '
                    f'({data_h}, {data_w}), but got ({results[k].shape[0]}, '
                    f'{results[k].shape[1]}).')

            data_, crop_bbox = self._crop(results[k], x_offset, y_offset,
                                          crop_w, crop_h)
            results[k] = data_
            results[k + '_crop_bbox'] = crop_bbox
        results['crop_size'] = self.crop_size
        results['crop_pos'] = self.crop_pos
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        # Parenthesize for consistency with the other pipeline transforms;
        # the original emitted 'FixedCropkeys=[...], ...' with no parens.
        repr_str += (f'(keys={self.keys}, crop_size={self.crop_size}, '
                     f'crop_pos={self.crop_pos})')
        return repr_str
build/lib/mmgen/datasets/pipelines/formatting.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
collections.abc
import
Sequence
import
mmcv
import
numpy
as
np
import
torch
from
mmcv.parallel
import
DataContainer
as
DC
from
..builder
import
PIPELINES
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    # Dispatch on type; check order matters (a Tensor/ndarray must be
    # handled before the generic Sequence branch).
    if isinstance(data, torch.Tensor):
        converted = data
    elif isinstance(data, np.ndarray):
        converted = torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        converted = torch.tensor(data)
    elif isinstance(data, int):
        converted = torch.LongTensor([data])
    elif isinstance(data, float):
        converted = torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
    return converted
@PIPELINES.register_module()
class ToTensor:
    """Convert some values in results dict to `torch.Tensor` type in data
    loader pipeline.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert each entry listed in ``self.keys`` with :func:`to_tensor`.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        results.update({key: to_tensor(results[key]) for key in self.keys})
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor:
    """Convert image type to `torch.Tensor` type.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __init__(self, keys, to_float32=True):
        self.keys = keys
        self.to_float32 = to_float32

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            # deal with gray scale img: expand a color channel
            if len(results[key].shape) == 2:
                results[key] = results[key][..., None]
            # Fix: the original `isinstance(results[key], np.float32)` tests
            # the array object itself against a scalar type and is therefore
            # always False, so the conversion ran even on arrays that were
            # already float32. Checking the dtype skips the redundant copy;
            # the resulting values are unchanged.
            if self.to_float32 and results[key].dtype != np.float32:
                results[key] = results[key].astype(np.float32)
            # HWC -> CHW before converting to a tensor
            results[key] = to_tensor(results[key].transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, to_float32={self.to_float32})')
@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img", "gt_labels".

    The "img_meta" item is always populated. The contents of the "meta"
    dictionary depends on "meta_keys".

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str]): Required keys to be collected to "meta".
            Default: None.
    """

    def __init__(self, keys, meta_keys=None):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        data = {}
        img_meta = {}
        # Fix: ``meta_keys`` defaults to None, but the original iterated it
        # unconditionally and raised TypeError with the documented default.
        # Treat None as "collect no meta keys".
        for key in self.meta_keys or ():
            img_meta[key] = results[key]
        data['meta'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, meta_keys={self.meta_keys})')
build/lib/mmgen/datasets/pipelines/loading.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
from
mmcv.fileio
import
FileClient
from
..builder
import
PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile:
    """Load image from file.

    Args:
        io_backend (str): io backend where images are store. Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        backend (str | None): The image decoding backend type. Options are
            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
            global imread_backend specified by ``mmcv.use_backend()`` will be
            used. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
            ``results`` dict with name of ``f'ori_{key}'``. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 backend=None,
                 save_original_img=False,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.channel_order = channel_order
        self.backend = backend
        self.save_original_img = save_original_img
        self.kwargs = kwargs
        # Created lazily on first call so the transform can be built in
        # worker processes before any file access happens.
        self.file_client = None

    def __call__(self, results):
        """Load the image referenced by ``results[f'{self.key}_path']``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        raw_bytes = self.file_client.get(filepath)
        img = mmcv.imfrombytes(
            raw_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(io_backend={self.io_backend}, key={self.key}, '
            f'flag={self.flag}, save_original_img={self.save_original_img})')
@PIPELINES.register_module()
class LoadPairedImageFromFile(LoadImageFromFile):
    """Load a pair of images from file.

    Each sample contains a pair of images, which are concatenated in the w
    dimension (a|b). This is a special loading class for generation paired
    dataset. It loads a pair of images as the common loader does and crops
    it into two images with the same shape in different domains.

    Required key is "pair_path". Added or modified keys are "pair",
    "pair_ori_shape", "ori_pair", "img_{domain_a}", "img_{domain_b}",
    "img_{domain_a}_path", "img_{domain_b}_path", "img_{domain_a}_ori_shape",
    "img_{domain_b}_ori_shape", "ori_img_{domain_a}" and
    "ori_img_{domain_b}".

    Args:
        io_backend (str): io backend where images are store. Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        domain_a (str, optional): One of the paired image domain.
            Defaults to None.
        domain_b (str, optional): The other image domain.
            Defaults to None.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='pair',
                 domain_a=None,
                 domain_b=None,
                 flag='color',
                 channel_order='bgr',
                 backend=None,
                 save_original_img=False,
                 **kwargs):
        super().__init__(
            io_backend,
            key=key,
            flag=flag,
            channel_order=channel_order,
            backend=backend,
            save_original_img=save_original_img,
            **kwargs)
        assert isinstance(domain_a, str)
        assert isinstance(domain_b, str)
        self.domain_a = domain_a
        self.domain_b = domain_b

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the width of the concatenated pair is odd, so it
                cannot be split into two equal halves.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        # Fix: forward ``channel_order`` and ``backend`` so the values
        # accepted by ``__init__`` are actually honored. Previously only
        # ``flag`` was passed and the other two were silently ignored
        # (defaults are unchanged, so default behavior is identical).
        img = mmcv.imfrombytes(
            img_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC
        if img.ndim == 2:
            # gray-scale: add a trailing channel axis for consistent slicing
            img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        # crop pair into a and b
        w = img.shape[1]
        if w % 2 != 0:
            raise ValueError(
                f'The width of image pair must be even number, but got {w}.')
        new_w = w // 2
        img_a = img[:, :new_w, :]
        img_b = img[:, new_w:, :]

        results[f'img_{self.domain_a}'] = img_a
        results[f'img_{self.domain_b}'] = img_b
        results[f'img_{self.domain_a}_path'] = filepath
        results[f'img_{self.domain_b}_path'] = filepath
        results[f'img_{self.domain_a}_ori_shape'] = img_a.shape
        results[f'img_{self.domain_b}_ori_shape'] = img_b.shape
        if self.save_original_img:
            results[f'ori_img_{self.domain_a}'] = img_a.copy()
            results[f'ori_img_{self.domain_b}'] = img_b.copy()

        return results
build/lib/mmgen/datasets/pipelines/normalize.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
from
..builder
import
PIPELINES
@PIPELINES.register_module()
class Normalize:
    """Normalize images with the given mean and std value.

    Required keys are the keys in attribute "keys", added or modified keys are
    the keys in attribute "keys" and these keys with postfix '_norm_cfg'.
    It also supports normalizing a list of images.

    Args:
        keys (Sequence[str]): The images to be normalized.
        mean (np.ndarray): Mean values of different channels.
        std (np.ndarray): Std values of different channels.
        to_rgb (bool): Whether to convert channels from BGR to RGB.
    """

    def __init__(self, keys, mean, std, to_rgb=False):
        self.keys = keys
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize every image listed in ``self.keys``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        normalize = mmcv.imnormalize
        for key in self.keys:
            value = results[key]
            if isinstance(value, list):
                results[key] = [
                    normalize(item, self.mean, self.std, self.to_rgb)
                    for item in value
                ]
            else:
                results[key] = normalize(value, self.mean, self.std,
                                         self.to_rgb)

        # record the normalization parameters for downstream consumers
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, mean={self.mean}, std={self.std}, '
            f'to_rgb={self.to_rgb})')
@PIPELINES.register_module()
class RescaleToZeroOne:
    """Transform the images into a range between 0 and 1.

    Required keys are the keys in attribute "keys", added or modified keys are
    the keys in attribute "keys".
    It also supports rescaling a list of images.

    Args:
        keys (Sequence[str]): The images to be transformed.
    """

    def __init__(self, keys):
        self.keys = keys

    @staticmethod
    def _rescale(img):
        # Map uint8-style intensities [0, 255] onto float32 [0, 1].
        return img.astype(np.float32) / 255.

    def __call__(self, results):
        """Rescale every image listed in ``self.keys`` to ``[0, 1]``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            value = results[key]
            if isinstance(value, list):
                results[key] = [self._rescale(item) for item in value]
            else:
                results[key] = self._rescale(value)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
build/lib/mmgen/datasets/quick_test_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
@DATASETS.register_module()
class QuickTestImageDataset(Dataset):
    """Dataset for quickly testing the correctness.

    Args:
        size (tuple[int]): The (H, W) size of the random image. Must be
            provided; ``None`` is rejected with a clear error.
    """

    def __init__(self, *args, size=None, **kwargs):
        super().__init__()
        # Fix: the original subscripted ``size`` unconditionally, so the
        # documented default of None crashed with an opaque
        # "'NoneType' object is not subscriptable". Fail early instead.
        if size is None:
            raise ValueError("'size' must be a (H, W) tuple, but got None.")
        self.size = size
        # One fixed random tensor, returned for every index.
        self.img_tensor = torch.randn(3, self.size[0], self.size[1])

    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        return dict(real_img=self.img_tensor)
build/lib/mmgen/datasets/samplers/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.distributed_sampler
import
DistributedSampler
__all__
=
[
'DistributedSampler'
]
build/lib/mmgen/datasets/samplers/distributed_sampler.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
__future__
import
division
import
numpy
as
np
import
torch
from
torch.utils.data
import
DistributedSampler
as
_DistributedSampler
from
mmgen.utils
import
sync_random_seed
class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    In pytorch of lower versions, there is no `shuffle` argument. This child
    class will port one to DistributedSampler.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 samples_per_gpu=1,
                 seed=None):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.samples_per_gpu = samples_per_gpu
        # fix the bug of the official implementation
        self._update_num_samples()

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def _update_num_samples(self):
        """Recompute the per-replica sample counts.

        Shared by ``__init__`` and ``update_sampler``; previously the same
        computation (and validation) was duplicated in both places.
        """
        self.num_samples_per_replica = int(
            np.ceil(
                len(self.dataset) * 1.0 / self.num_replicas /
                self.samples_per_gpu))
        self.num_samples = self.num_samples_per_replica * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

        # to avoid padding bug when meeting too small dataset
        if len(self.dataset) < self.num_replicas * self.samples_per_gpu:
            raise ValueError(
                'You may use too small dataset and our distributed '
                'sampler cannot pad your dataset correctly. We highly '
                'recommend you to use fewer GPUs to finish your work')

    def update_sampler(self, dataset, samples_per_gpu=None):
        """Swap in a new dataset (and optionally ``samples_per_gpu``) and
        refresh the derived sample counts.

        Args:
            dataset: The new dataset to sample from.
            samples_per_gpu (int, optional): New number of samples per GPU.
                Kept unchanged when None. Defaults to None.
        """
        self.dataset = dataset
        if samples_per_gpu is not None:
            self.samples_per_gpu = samples_per_gpu
        self._update_num_samples()

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
build/lib/mmgen/datasets/singan_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
mmcv
import
numpy
as
np
import
torch
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
def create_real_pyramid(real, min_size, max_size, scale_factor_init):
    """Create image pyramid.

    This function is modified from the official implementation:
    https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221

    In this implementation, we adopt the rescaling function from MMCV.

    Args:
        real (np.array): The real image array.
        min_size (int): The minimum size for the image pyramid.
        max_size (int): The maximum size for the image pyramid.
        scale_factor_init (float): The initial scale factor.

    Returns:
        tuple: ``(reals, scale_factor, stop_scale)`` where ``reals`` is the
        list of rescaled images from coarsest to finest, ``scale_factor`` is
        the per-level factor actually used, and ``stop_scale`` is the index
        of the finest level.
    """
    # Number of pyramid levels implied by shrinking the shorter edge down
    # to ``min_size`` with factor ``scale_factor_init``.
    # NOTE(review): ``np.power(..., 1)`` is a no-op kept from the upstream
    # SinGAN code.
    num_scales = int(
        np.ceil(
            np.log(np.power(min_size / min(real.shape[0], real.shape[1]), 1))
            / np.log(scale_factor_init))) + 1
    # Levels skipped at the top because the image is capped at ``max_size``.
    scale2stop = int(
        np.ceil(
            np.log(
                min([max_size, max([real.shape[0], real.shape[1]])]) /
                max([real.shape[0], real.shape[1]])) /
            np.log(scale_factor_init)))
    stop_scale = num_scales - scale2stop

    # Rescale so the longer edge does not exceed ``max_size`` (factor <= 1).
    scale1 = min(max_size / max([real.shape[0], real.shape[1]]), 1)
    real_max = mmcv.imrescale(real, scale1)
    # Effective per-level factor so exactly ``stop_scale`` steps reach
    # ``min_size`` from the capped image.
    scale_factor = np.power(
        min_size / (min(real_max.shape[0], real_max.shape[1])),
        1 / (stop_scale))
    # NOTE(review): recomputed with identical expressions as above (kept
    # from the upstream implementation); the values do not change here.
    scale2stop = int(
        np.ceil(
            np.log(
                min([max_size, max([real.shape[0], real.shape[1]])]) /
                max([real.shape[0], real.shape[1]])) /
            np.log(scale_factor_init)))
    stop_scale = num_scales - scale2stop

    # Build the pyramid from coarsest (i=0) to finest (i=stop_scale).
    reals = []
    for i in range(stop_scale + 1):
        scale = np.power(scale_factor, stop_scale - i)
        curr_real = mmcv.imrescale(real, scale)
        reals.append(curr_real)

    return reals, scale_factor, stop_scale
@DATASETS.register_module()
class SinGANDataset(Dataset):
    """SinGAN Dataset.

    In this dataset, we create an image pyramid and save it in the cache.

    Args:
        img_path (str): Path to the single image file.
        min_size (int): Min size of the image pyramid. Here, the number will be
            set to the ``min(H, W)``.
        max_size (int): Max size of the image pyramid. Here, the number will be
            set to the ``max(H, W)``.
        scale_factor_init (float): Rescale factor. Note that the actual factor
            we use may be a little bit different from this value.
        num_samples (int, optional): The number of samples (length) in this
            dataset. Defaults to -1.
    """

    def __init__(self,
                 img_path,
                 min_size,
                 max_size,
                 scale_factor_init,
                 num_samples=-1):
        self.img_path = img_path
        assert mmcv.is_filepath(self.img_path)
        self.load_annotations(min_size, max_size, scale_factor_init)
        self.num_samples = num_samples

    def load_annotations(self, min_size, max_size, scale_factor_init):
        """Load annatations for SinGAN Dataset.

        Args:
            min_size (int): The minimum size for the image pyramid.
            max_size (int): The maximum size for the image pyramid.
            scale_factor_init (float): The initial scale factor.
        """
        real = mmcv.imread(self.img_path)
        self.reals, self.scale_factor, self.stop_scale = create_real_pyramid(
            real, min_size, max_size, scale_factor_init)

        # Cache every pyramid level as a normalized tensor once.
        self.data_dict = {
            f'real_scale{i}': self._img2tensor(scale_img)
            for i, scale_img in enumerate(self.reals)
        }
        self.data_dict['input_sample'] = torch.zeros_like(
            self.data_dict['real_scale0'])

    def _img2tensor(self, img):
        """Convert an HWC uint8 array to a CHW float tensor in ``[-1, 1]``."""
        tensor = torch.from_numpy(img).to(torch.float32)
        tensor = tensor.permute(2, 0, 1).contiguous()
        return (tensor / 255 - 0.5) * 2

    def __getitem__(self, index):
        # Every index returns the same cached pyramid dict.
        return self.data_dict

    def __len__(self):
        return int(1e6) if self.num_samples < 0 else self.num_samples
build/lib/mmgen/datasets/unconditional_image_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
mmcv
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
from
.pipelines
import
Compose
@DATASETS.register_module()
class UnconditionalImageDataset(Dataset):
    """Unconditional Image Dataset.

    This dataset contains raw images for training unconditional GANs. Given
    a root dir, we will recursively find all images in this root. The
    transformation on data is defined by the pipeline.

    Args:
        imgs_root (str): Root path for unconditional images.
        pipeline (list[dict | callable]): A sequence of data transforms.
        test_mode (bool, optional): If True, the dataset will work in test
            mode. Otherwise, in train mode. Default to False.
    """

    _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG')

    def __init__(self, imgs_root, pipeline, test_mode=False):
        super().__init__()
        self.imgs_root = imgs_root
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        self.load_annotations()

        # print basic dataset information to check the validity
        mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations."""
        # recursively find all of the valid images from imgs_root
        found = mmcv.scandir(
            self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True)
        self.imgs_list = [osp.join(self.imgs_root, name) for name in found]

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index of current batch.

        Returns:
            dict: Prepared training data batch.
        """
        results = dict(real_img_path=self.imgs_list[idx])
        return self.pipeline(results)

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __repr__(self):
        return (f'dataset_name: {self.__class__}, total {len(self)} images in '
                f'imgs_root: {self.imgs_root}')
build/lib/mmgen/datasets/unpaired_image_dataset.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
pathlib
import
Path
import
numpy
as
np
from
mmcv
import
scandir
from
torch.utils.data
import
Dataset
from
.builder
import
DATASETS
from
.pipelines
import
Compose
# Image file suffixes accepted when scanning dataset folders. Both lower-
# and upper-case variants are listed explicitly; presumably the suffix
# matching in `scandir` is case-sensitive — TODO confirm.
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
@
DATASETS
.
register_module
()
class
UnpairedImageDataset
(
Dataset
):
"""General unpaired image folder dataset for image generation.
It assumes that the training directory of images from domain A is
'/path/to/data/trainA', and that from domain B is '/path/to/data/trainB',
respectively. '/path/to/data' can be initialized by args 'dataroot'.
During test time, the directory is '/path/to/data/testA' and
'/path/to/data/testB', respectively.
Args:
dataroot (str | :obj:`Path`): Path to the folder root of unpaired
images.
pipeline (List[dict | callable]): A sequence of data transformations.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
domain_a (str, optional): Domain of images in trainA / testA.
Defaults to None.
domain_b (str, optional): Domain of images in trainB / testB.
Defaults to None.
"""
def
__init__
(
self
,
dataroot
,
pipeline
,
test_mode
=
False
,
domain_a
=
None
,
domain_b
=
None
):
super
().
__init__
()
phase
=
'test'
if
test_mode
else
'train'
self
.
dataroot_a
=
osp
.
join
(
str
(
dataroot
),
phase
+
'A'
)
self
.
dataroot_b
=
osp
.
join
(
str
(
dataroot
),
phase
+
'B'
)
self
.
data_infos_a
=
self
.
load_annotations
(
self
.
dataroot_a
)
self
.
data_infos_b
=
self
.
load_annotations
(
self
.
dataroot_b
)
self
.
len_a
=
len
(
self
.
data_infos_a
)
self
.
len_b
=
len
(
self
.
data_infos_b
)
self
.
test_mode
=
test_mode
self
.
pipeline
=
Compose
(
pipeline
)
assert
isinstance
(
domain_a
,
str
)
assert
isinstance
(
domain_b
,
str
)
self
.
domain_a
=
domain_a
self
.
domain_b
=
domain_b
def
load_annotations
(
self
,
dataroot
):
"""Load unpaired image paths of one domain.
Args:
dataroot (str): Path to the folder root for unpaired images of
one domain.
Returns:
list[dict]: List that contains unpaired image paths of one domain.
"""
data_infos
=
[]
paths
=
sorted
(
self
.
scan_folder
(
dataroot
))
for
path
in
paths
:
data_infos
.
append
(
dict
(
path
=
path
))
return
data_infos
def
prepare_train_data
(
self
,
idx
):
"""Prepare unpaired training data.
Args:
idx (int): Index of current batch.
Returns:
dict: Prepared training data batch.
"""
img_a_path
=
self
.
data_infos_a
[
idx
%
self
.
len_a
][
'path'
]
idx_b
=
np
.
random
.
randint
(
0
,
self
.
len_b
)
img_b_path
=
self
.
data_infos_b
[
idx_b
][
'path'
]
results
=
dict
()
results
[
f
'img_
{
self
.
domain_a
}
_path'
]
=
img_a_path
results
[
f
'img_
{
self
.
domain_b
}
_path'
]
=
img_b_path
return
self
.
pipeline
(
results
)
def
prepare_test_data
(
self
,
idx
):
"""Prepare unpaired test data.
Args:
idx (int): Index of current batch.
Returns:
list[dict]: Prepared test data batch.
"""
img_a_path
=
self
.
data_infos_a
[
idx
%
self
.
len_a
][
'path'
]
img_b_path
=
self
.
data_infos_b
[
idx
%
self
.
len_b
][
'path'
]
results
=
dict
()
results
[
f
'img_
{
self
.
domain_a
}
_path'
]
=
img_a_path
results
[
f
'img_
{
self
.
domain_b
}
_path'
]
=
img_b_path
return
self
.
pipeline
(
results
)
def
__len__
(
self
):
return
max
(
self
.
len_a
,
self
.
len_b
)
@
staticmethod
def
scan_folder
(
path
):
"""Obtain image path list (including sub-folders) from a given folder.
Args:
path (str | :obj:`Path`): Folder path.
Returns:
list[str]: Image list obtained from the given folder.
"""
if
isinstance
(
path
,
(
str
,
Path
)):
path
=
str
(
path
)
else
:
raise
TypeError
(
"'path' must be a str or a Path object, "
f
'but received
{
type
(
path
)
}
.'
)
images
=
scandir
(
path
,
suffix
=
IMG_EXTENSIONS
,
recursive
=
True
)
images
=
[
osp
.
join
(
path
,
v
)
for
v
in
images
]
assert
images
,
f
'
{
path
}
has no valid image file.'
return
images
def __getitem__(self, idx):
    """Get item at each call.

    Args:
        idx (int): Index for getting each item.
    """
    # Test mode pairs both domains deterministically; train mode draws a
    # random domain-B counterpart.
    if self.test_mode:
        return self.prepare_test_data(idx)
    return self.prepare_train_data(idx)
build/lib/mmgen/models/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from .architectures import *  # noqa: F401, F403
from .builder import MODELS, MODULES, build_model, build_module
from .common import *  # noqa: F401, F403
from .diffusions import *  # noqa: F401, F403
from .gans import *  # noqa: F401, F403
from .losses import *  # noqa: F401, F403
from .misc import *  # noqa: F401, F403
from .translation_models import *  # noqa: F401, F403

# The star imports above populate this package's namespace (and, presumably,
# trigger submodule registry registration -- TODO confirm); since __all__ is
# defined, `from mmgen.models import *` re-exports only the registries and
# builder helpers listed here.
__all__ = ['build_model', 'MODELS', 'build_module', 'MODULES']
build/lib/mmgen/models/architectures/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from .arcface import IDLossModel
from .biggan import (BigGANDeepDiscriminator, BigGANDeepGenerator,
                     BigGANDiscriminator, BigGANGenerator, SNConvModule)
from .cyclegan import ResnetGenerator
from .dcgan import DCGANDiscriminator, DCGANGenerator
from .ddpm import DenoisingUnet
from .fid_inception import InceptionV3
from .lpips import PerceptualLoss
from .lsgan import LSGANDiscriminator, LSGANGenerator
from .pggan import (EqualizedLR, EqualizedLRConvDownModule,
                    EqualizedLRConvModule, EqualizedLRConvUpModule,
                    EqualizedLRLinearModule, MiniBatchStddevLayer,
                    PGGANDiscriminator, PGGANGenerator, PGGANNoiseTo2DFeat,
                    PixelNorm, equalized_lr)
from .pix2pix import PatchDiscriminator, generation_init_weights
from .positional_encoding import CatersianGrid, SinusoidalPositionalEmbedding
from .singan import SinGANMultiScaleDiscriminator, SinGANMultiScaleGenerator
from .sngan_proj import ProjDiscriminator, SNGANGenerator
from .stylegan import (MSStyleGAN2Discriminator, MSStyleGANv2Generator,
                       StyleGAN1Discriminator, StyleGAN2Discriminator,
                       StyleGANv1Generator, StyleGANv2Generator,
                       StyleGANv3Generator)
from .wgan_gp import WGANGPDiscriminator, WGANGPGenerator

# Explicit public API of the architectures subpackage:
# `from mmgen.models.architectures import *` exposes exactly these names.
__all__ = [
    'DCGANGenerator', 'DCGANDiscriminator', 'EqualizedLR',
    'EqualizedLRConvModule', 'equalized_lr', 'EqualizedLRLinearModule',
    'EqualizedLRConvUpModule', 'EqualizedLRConvDownModule', 'PixelNorm',
    'MiniBatchStddevLayer', 'PGGANNoiseTo2DFeat', 'PGGANGenerator',
    'PGGANDiscriminator', 'InceptionV3', 'SinGANMultiScaleDiscriminator',
    'SinGANMultiScaleGenerator', 'CatersianGrid',
    'SinusoidalPositionalEmbedding', 'StyleGAN2Discriminator',
    'StyleGANv2Generator', 'StyleGANv1Generator', 'StyleGAN1Discriminator',
    'MSStyleGAN2Discriminator', 'MSStyleGANv2Generator',
    'generation_init_weights', 'PatchDiscriminator', 'ResnetGenerator',
    'PerceptualLoss', 'WGANGPDiscriminator', 'WGANGPGenerator',
    'LSGANDiscriminator', 'LSGANGenerator', 'ProjDiscriminator',
    'SNGANGenerator', 'BigGANGenerator', 'SNConvModule',
    'BigGANDiscriminator', 'BigGANDeepGenerator', 'BigGANDeepDiscriminator',
    'DenoisingUnet', 'StyleGANv3Generator', 'IDLossModel'
]
Prev
1
…
8
9
10
11
12
13
14
15
16
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment