Commit 57e0e891 authored by limm's avatar limm
Browse files

add part mmgeneration code

parent 04e07f48
# StyleGAN3-R generator/discriminator overrides for 1024x1024 output.
_base_ = ['./stylegan3_base.py']

# Synthesis settings for the "r" (rotation-equivariant) variant: 1x1 convs
# and radial filters, with doubled channel budget.
synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=65536,
    channel_max=1024,
    magnitude_ema_beta=0.999,
    conv_kernel=1,
    use_radial_filters=True)

# R1 gradient-penalty weight and lazy-regularization interval.
r1_gamma = 32.8
d_reg_interval = 16

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        synthesis_cfg=synthesis_cfg,
        rgb2bgr=True),
    discriminator=dict(type='StyleGAN2Discriminator', in_size=1024))
# StyleGAN3-R overrides for 256x256 output with a slimmer discriminator.
_base_ = ['./stylegan3_base.py']

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=32768,
    channel_max=1024,
    magnitude_ema_beta=0.999,
    conv_kernel=1,
    use_radial_filters=True)

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=256,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(in_size=256, channel_multiplier=1))
# StyleGAN3-T + ADA fine-tuning on MetFaces 1024x1024, initialised from the
# official FFHQ 1024 checkpoint (see `load_from`).
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=32768,
    channel_max=512,
    magnitude_ema_beta=0.999)

r1_gamma = 6.6  # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth'  # noqa

# ada settings: every augmentation enabled with probability multiplier 1
aug_kwargs = {
    aug: 1
    for aug in ('xflip', 'rotate90', 'xint', 'scale', 'rotate', 'aniso',
                'xfrac', 'brightness', 'contrast', 'lumaflip', 'hue',
                'saturation')
}

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        input_bgr2rgb=True,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

# Generator EMA with a fixed momentum derived from a 10k-image half-life.
ema_half_life = 10.  # G_smoothing_kimg
ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=ema_beta),
        interval=1,
        start_iter=0,
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

# Evaluate every 10k iterations, switching to every 5k after 80k.
evaluation = dict(
    type='GenerativeEvalHook',
    interval=dict(milestones=[80000], interval=[10000, 5000]),
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

lr_config = None
total_iters = 160000
# StyleGAN3-T overrides for 512x512 output.
_base_ = ['./stylegan3_base.py']

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=32768,
    channel_max=512,
    magnitude_ema_beta=0.999)

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=512,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(in_size=512))
# StyleGAN3-T overrides for 1024x1024 output.
_base_ = ['./stylegan3_base.py']

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=32768,
    channel_max=512,
    magnitude_ema_beta=0.999)

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        synthesis_cfg=synthesis_cfg,
        rgb2bgr=True),
    discriminator=dict(in_size=1024))
# StyleGAN3-T overrides for 256x256 output with a slimmer discriminator.
_base_ = ['./stylegan3_base.py']

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=16384,
    channel_max=512,
    magnitude_ema_beta=0.999)

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=256,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(in_size=256, channel_multiplier=1))
# StyleGAN3-T training on 256x256 images (flip + lanczos-resize pipeline),
# without ADA augmentation.
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/unconditional_imgs_flip_lanczos_resize_256x256.py',
    '../_base_/default_runtime.py'
]

synthesis_cfg = dict(
    type='SynthesisNetwork',
    channel_base=16384,
    channel_max=512,
    magnitude_ema_beta=0.999)

r1_gamma = 2.  # set by user
d_reg_interval = 16

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(out_size=256, img_channels=3, synthesis_cfg=synthesis_cfg),
    discriminator=dict(in_size=256, channel_multiplier=1),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/ffhq/images'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10.  # G_smoothing_kimg

# Generator EMA with ramped-up momentum, as in the official implementation.
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interval=1,
        start_iter=0,
        momentum_policy='rampup',
        momentum_cfg=dict(
            ema_kimg=10, ema_rampup=0.05, batch_size=32, eps=1e-8),
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/ffhq-lanczos-256x256.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

inception_path = None
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN', inception_path=inception_path),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
lr_config = None
total_iters = 800002
# StyleGAN3-T training on FFHQ 1024x1024 (no ADA augmentation).
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

# Total batch size across GPUs; drives both the magnitude-EMA decay below
# and the EMA-rampup hook config.
batch_size = 32
magnitude_ema_beta = 0.5**(batch_size / (20 * 1e3))

synthesis_cfg = {
    'type': 'SynthesisNetwork',
    'channel_base': 32768,
    'channel_max': 512,
    # Fix: use the batch-size-dependent decay computed above. It was
    # previously dead code, shadowed here by a hard-coded 0.999.
    'magnitude_ema_beta': magnitude_ema_beta
}

# R1 gradient-penalty weight and lazy-regularization interval.
r1_gamma = 32.8
d_reg_interval = 16

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(out_size=1024, img_channels=3, synthesis_cfg=synthesis_cfg),
    discriminator=dict(in_size=1024),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = None  # set by user
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10.  # G_smoothing_kimg
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interval=1,
        start_iter=0,
        momentum_policy='rampup',
        momentum_cfg=dict(
            ema_kimg=10, ema_rampup=0.05, batch_size=batch_size, eps=1e-8),
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/ffhq_noflip_1024x1024.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
lr_config = None
total_iters = 800002
# WGAN-GP
> [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)
<!-- [ALGORITHM] -->
## Abstract
<!-- [ABSTRACT] -->
Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning, including 101-layer ResNets and language models over discrete data. We also achieve high quality generations on CIFAR-10 and LSUN bedrooms.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/28132635/143154792-de359728-101b-4ad1-90c0-ef3c1572d184.png"/>
</div>
## Results and models
<div align="center">
<b> WGAN-GP 128, CelebA-Cropped</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/113997469-c00e3f00-988a-11eb-81dc-19b05698b74b.png" width="800"/>
</div>
| Models | Dataset | Details | SWD | MS-SSIM | Config | Download |
| :---------: | :------------: | :----------------: | :---------------------------: | :-----: | :---------------------------------------------------------: | :------------------------------------------------------------: |
| WGAN-GP 128 | CelebA-Cropped | GN | 5.87, 9.76, 9.43, 18.84/10.97 | 0.2601 | [config](https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_celeba-cropped_128_b64x1_160kiter.py) | [model](https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_celeba-cropped_128_b64x1_160k_20210408_170611-f8a99336.pth) |
| WGAN-GP 128 | LSUN-Bedroom | GN, GP-lambda = 50 | 11.7, 7.87, 9.82, 25.36/13.69 | 0.059 | [config](https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py) | [model](https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_130k_20210408_170509-56f2a37c.pth) |
## Citation
```latex
@article{gulrajani2017improved,
title={Improved Training of Wasserstein GANs},
author={Gulrajani, Ishaan and Ahmed, Faruk and Arjovsky, Martin and Dumoulin, Vincent and Courville, Aaron},
journal={arXiv preprint arXiv:1704.00028},
year={2017},
url={https://arxiv.org/abs/1704.00028},
}
```
# Model-index metadata for the WGAN-GP collection (consumed by OpenMMLab's
# model-index tooling; human-readable version in configs/wgan-gp/README.md).
Collections:
- Metadata:
    Architecture:
    - WGAN-GP
  Name: WGAN-GP
  Paper:
  - https://arxiv.org/abs/1704.00028
  README: configs/wgan-gp/README.md
Models:
- Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_celeba-cropped_128_b64x1_160kiter.py
  In Collection: WGAN-GP
  Metadata:
    Training Data: CELEBA
  Name: wgangp_GN_celeba-cropped_128_b64x1_160kiter
  Results:
  - Dataset: CELEBA
    Metrics:
      Details: GN
      MS-SSIM: 0.2601
      SWD: 5.87, 9.76, 9.43, 18.84/10.97
    Task: Unconditional GANs
  Weights: https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_celeba-cropped_128_b64x1_160k_20210408_170611-f8a99336.pth
- Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py
  In Collection: WGAN-GP
  Metadata:
    Training Data: LSUN
  Name: wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter
  Results:
  - Dataset: LSUN
    Metrics:
      Details: GN, GP-lambda = 50
      MS-SSIM: 0.059
      SWD: 11.7, 7.87, 9.82, 25.36/13.69
    Task: Unconditional GANs
  Weights: https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_130k_20210408_170509-56f2a37c.pth
# WGAN-GP on LSUN-Bedroom 128x128: inherit the CelebA recipe and override
# the gradient penalty (lambda = 50) and the training image root.
_base_ = ['./wgangp_GN_celeba-cropped_128_b64x1_160kiter.py']

model = dict(disc_auxiliary_loss=[
    dict(
        type='GradientPenaltyLoss',
        loss_weight=50,
        norm_mode='HWC',
        data_info=dict(
            discriminator='disc', real_data='real_imgs',
            fake_data='fake_imgs'))
])

data = dict(
    samples_per_gpu=64, train=dict(imgs_root='./data/lsun/bedroom_train'))
# WGAN-GP base training schedule on CelebA-Cropped 128x128.
_base_ = [
    '../_base_/datasets/unconditional_imgs_128x128.py',
    '../_base_/models/wgangp/wgangp_base.py'
]

data = dict(
    samples_per_gpu=64,
    train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/'))

checkpoint_config = dict(interval=10000, by_epoch=False)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=1000)
]

lr_config = None
total_iters = 160000

metrics = dict(
    ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
    swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)))
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import mmcv
from mmcv import DictAction
from torchvision import utils
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import init_model, sample_conditional_model # isort:skip # noqa
# yapf: enable
def parse_args():
    """Parse command-line arguments for the conditional sampling demo.

    Returns:
        argparse.Namespace: Parsed arguments (config/checkpoint paths,
            sampling options and image-grid layout options).
    """
    parser = argparse.ArgumentParser(description='Generation demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--save-path',
        type=str,
        default='./work_dirs/demos/conditional_samples.png',
        help='path to save unconditional samples')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CUDA device id')

    # args for inference/sampling
    parser.add_argument(
        '--num-batches', type=int, default=4, help='Batch size in inference')
    parser.add_argument(
        '--samples-per-classes',
        type=int,
        default=5,
        help=('This argument work together with `label`, and decide the '
              'number of samples to generate for each class in the given '
              '`label`. If `label` is not given, samples-per-classes would '
              'be regard as the total number of the images to sample.'))
    parser.add_argument(
        '--label',
        type=int,
        nargs='+',
        help=('Labels want to sample. If not defined, '
              'random sampling would be applied.'))
    parser.add_argument(
        '--sample-all-classes',
        action='store_true',
        help='Whether sample all classes of the dataset.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='Which model to use for sampling')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,  # mmcv helper: parses KEY=VALUE pairs into a dict
        help='Other customized kwargs for sampling function')

    # args for image grid
    parser.add_argument(
        '--padding', type=int, default=0, help='Padding in the image grid.')
    parser.add_argument(
        '--nrow',
        type=int,
        default=6,
        help=('Number of images displayed in each row of the grid. '
              'This argument would work only when label is not given.'))
    args = parser.parse_args()
    return args
def main():
    """Sample images from a conditional GAN and save them as one grid."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if args.label is None and not args.sample_all_classes:
        # No labels requested: let the sampler pick random classes.
        label = None
        num_samples, nrow = args.samples_per_classes, args.nrow
        mmcv.print_log(
            '`label` is not passed, code would randomly sample '
            f'`samples-per-classes` (={num_samples}) images.', 'mmgen')
    else:
        if args.sample_all_classes:
            mmcv.print_log(
                '`sample_all_classes` is set as True, `num-samples`, `label`, '
                'and `nrows` would be ignored.', 'mmgen')
            # get num_classes
            if hasattr(model, 'num_classes') and model.num_classes is not None:
                num_classes = model.num_classes
            else:
                # Fix: AttributeError takes only message arguments; the stray
                # trailing 'mmgen' (copied from the print_log signature) made
                # the raised error print as a 2-tuple.
                raise AttributeError(
                    'Cannot get attribute `num_classes` from '
                    f'{type(model)}. Please check your config.')
            # build label list covering every class
            meta_labels = list(range(num_classes))
        else:
            # deduplicate and sort the user-supplied labels
            meta_labels = sorted(set(args.label))

        # repeat each class label `samples_per_classes` times
        label = []
        for idx in meta_labels:
            label += [idx] * args.samples_per_classes
        num_samples = len(label)
        nrow = args.samples_per_classes
        mmcv.print_log(
            'Set `nrows` as number of samples for each class '
            f'(={args.samples_per_classes}).', 'mmgen')

    results = sample_conditional_model(model, num_samples, args.num_batches,
                                       args.sample_model, label,
                                       **args.sample_cfg)
    # BGR -> RGB and [-1, 1] -> [0, 1] for torchvision's save_image.
    results = (results[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(results, args.save_path, nrow=nrow, padding=args.padding)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import mmcv
import numpy as np
import torch
from mmcv import DictAction
from torchvision import utils
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import init_model, sample_ddpm_model # isort:skip # noqa
# yapf: enable
def parse_args():
    """Parse command-line arguments for the DDPM sampling demo.

    Returns:
        argparse.Namespace: Parsed arguments (config/checkpoint paths,
            sampling options, GIF options and image-grid layout options).
    """
    parser = argparse.ArgumentParser(description='DDPM demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--save-path',
        type=str,
        default='./work_dirs/demos/ddpm_samples.png',
        help='path to save unconditional samples')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CUDA device id')

    # args for inference/sampling
    parser.add_argument(
        '--num-batches', type=int, default=4, help='Batch size in inference')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=12,
        help='The total number of samples')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='Which model to use for sampling')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,  # mmcv helper: parses KEY=VALUE pairs into a dict
        help='Other customized kwargs for sampling function')
    parser.add_argument(
        '--same-noise',
        action='store_true',
        help='whether use same noise as input (x_T)')
    parser.add_argument(
        '--n-skip',
        type=int,
        default=25,
        help=('Skip how many steps before selecting one to visualize. This is '
              'helpful with denoising timestep is too much. Only work with '
              '`save-path` is end with \'.gif\'.'))

    # args for image grid
    parser.add_argument(
        '--padding', type=int, default=0, help='Padding in the image grid.')
    parser.add_argument(
        '--nrow',
        type=int,
        default=2,
        help=('Number of images displayed in each row of the grid. '
              'This argument would work only when label is not given.'))

    # args for image channel order
    parser.add_argument(
        '--is-rgb',
        action='store_true',
        help=('If true, color channels will not be permuted, This option is '
              'useful when inference model trained with rgb images.'))
    args = parser.parse_args()
    return args
def create_gif(results, gif_name, fps=60, n_skip=1):
    """Create a gif from a sequence of image frames through imageio.

    Args:
        results (torch.Tensor): Image frames, shape like [bz, 3, H, W],
            with values assumed in [0, 1].
        gif_name (str): Saved gif name.
        fps (int, optional): Frames per second of the generated gif.
            Defaults to 60.
        n_skip (int, optional): Skip how many steps before selecting one to
            visualize. Defaults to 1.
    """
    try:
        import imageio
    except ImportError as err:
        # Fix: the original message used curly quotes and lacked a space
        # after the comma; also chain the ImportError for easier debugging.
        raise RuntimeError('imageio is not installed, please use '
                           '"pip install imageio" to install') from err

    def _to_uint8(frame):
        # [3, H, W] float in [0, 1] -> [H, W, 3] uint8 for imageio.
        return (frame.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)

    frames_list = [_to_uint8(frame) for frame in results[::n_skip]]
    # ensure the final denoising result is included in frames_list
    if len(results) % n_skip != 0:
        frames_list.append(_to_uint8(results[-1]))

    imageio.mimsave(gif_name, frames_list, 'GIF', fps=fps)
def main():
    """Sample from a DDPM and save the result as an image grid or a GIF."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    # GIF output needs the intermediate denoising states as frames.
    suffix = osp.splitext(args.save_path)[-1]
    if suffix == '.gif':
        args.sample_cfg['save_intermedia'] = True
    results = sample_ddpm_model(model, args.num_samples, args.num_batches,
                                args.sample_model, args.same_noise,
                                **args.sample_cfg)

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    if suffix == '.gif':
        # NOTE(review): in this branch `results` appears to be a mapping
        # timestep -> batch tensor (it is iterated via .keys()) — confirm
        # against sample_ddpm_model's contract.
        # concentrate all output of each timestep
        results_timestep_list = []
        for t in results.keys():
            # make grid
            results_timestep = utils.make_grid(
                results[t], nrow=args.nrow, padding=args.padding)
            # unsqueeze at 0, because make grid output is size like [H', W', 3]
            results_timestep_list.append(results_timestep[None, ...])

        # Concatenates to [n_timesteps, H', W', 3]
        results_timestep = torch.cat(results_timestep_list, dim=0)
        if not args.is_rgb:
            # swap BGR -> RGB unless the model was trained on RGB images
            results_timestep = results_timestep[:, [2, 1, 0]]

        # map [-1, 1] to [0, 1] before writing frames
        results_timestep = (results_timestep + 1.) / 2.
        create_gif(results_timestep, args.save_path, n_skip=args.n_skip)
    else:
        if not args.is_rgb:
            results = results[:, [2, 1, 0]]

        results = (results + 1.) / 2.
        utils.save_image(
            results, args.save_path, nrow=args.nrow, padding=args.padding)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import mmcv
from mmcv import DictAction
from torchvision import utils
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import init_model, sample_img2img_model # isort:skip # noqa
# yapf: enable
def parse_args():
    """Parse command-line arguments for the image translation demo.

    Returns:
        argparse.Namespace: Parsed arguments (config/checkpoint/image paths,
            target domain and sampling options).
    """
    parser = argparse.ArgumentParser(description='Translation demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('image_path', help='Image file path')
    parser.add_argument(
        '--target-domain', type=str, default=None, help='Desired image domain')
    parser.add_argument(
        '--save-path',
        type=str,
        default='./work_dirs/demos/translation_sample.png',
        help='path to save translation sample')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CUDA device id')

    # args for inference/sampling
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,  # mmcv helper: parses KEY=VALUE pairs into a dict
        help='Other customized kwargs for sampling function')
    args = parser.parse_args()
    return args
def main():
    """Translate one input image to the target domain and save the result."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    # Fall back to an empty kwargs dict when no sampling options were given.
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    translated = sample_img2img_model(model, args.image_path,
                                      args.target_domain, **args.sample_cfg)
    # BGR -> RGB, then map [-1, 1] to [0, 1] for torchvision's save_image.
    translated = (translated[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(translated, args.save_path)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import mmcv
from mmcv import DictAction
from torchvision import utils
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import init_model, sample_unconditional_model # isort:skip # noqa
# yapf: enable
def parse_args():
    """Parse command-line arguments for the unconditional sampling demo.

    Returns:
        argparse.Namespace: Parsed arguments (config/checkpoint paths,
            sampling options and image-grid layout options).
    """
    parser = argparse.ArgumentParser(description='Generation demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--save-path',
        type=str,
        default='./work_dirs/demos/unconditional_samples.png',
        help='path to save unconditional samples')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CUDA device id')

    # args for inference/sampling
    parser.add_argument(
        '--num-batches', type=int, default=4, help='Batch size in inference')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=12,
        help='The total number of samples')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='Which model to use for sampling')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,  # mmcv helper: parses KEY=VALUE pairs into a dict
        help='Other customized kwargs for sampling function')

    # args for image grid
    parser.add_argument(
        '--padding', type=int, default=0, help='Padding in the image grid.')
    parser.add_argument(
        '--nrow',
        type=int,
        default=6,
        help='Number of images displayed in each row of the grid')
    args = parser.parse_args()
    return args
def main():
    """Sample unconditional images and save them as one grid."""
    args = parse_args()
    model = init_model(
        args.config, checkpoint=args.checkpoint, device=args.device)

    # Fall back to an empty kwargs dict when no sampling options were given.
    if args.sample_cfg is None:
        args.sample_cfg = dict()

    imgs = sample_unconditional_model(model, args.num_samples,
                                      args.num_batches, args.sample_model,
                                      **args.sample_cfg)
    # BGR -> RGB, then map [-1, 1] to [0, 1] for torchvision's save_image.
    imgs = (imgs[:, [2, 1, 0]] + 1.) / 2.

    # save images
    mmcv.mkdir_or_exist(os.path.dirname(args.save_path))
    utils.save_image(
        imgs, args.save_path, nrow=args.nrow, padding=args.padding)


if __name__ == '__main__':
    main()
# Build an MMGeneration development image on top of the official PyTorch
# devel image (includes the CUDA toolchain needed to compile CUDA ops).
ARG PYTORCH="1.8.0"
ARG CUDA="11.1"
ARG CUDNN="8"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

# Target CUDA architectures for compiled extensions.
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"

# System libraries required by OpenCV/ffmpeg-based image & video IO,
# plus git and ninja for building from source.
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install MMCV
RUN pip install mmcv-full==1.3.16 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html

# Install MMGeneration
RUN conda clean --all
RUN git clone https://github.com/open-mmlab/mmgeneration.git /mmgen
WORKDIR /mmgen
# Build CUDA ops even when no GPU is visible at image-build time.
ENV FORCE_CUDA="1"
RUN pip install -r requirements.txt
RUN pip install --no-cache-dir -e .
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
# (Recipe lines below must be indented with a real TAB character.)
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
/* Sphinx theme override: swap the header logo for the MMGeneration banner
   (fixed 173x40 box matching the image's intrinsic size). */
.header-logo {
    background-image: url("https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png");
    background-size: 173px 40px;
    height: 40px;
    width: 173px;
}
mmgen.apis
--------------
.. automodule:: mmgen.apis
:members:
mmgen.core
--------------
evaluation
^^^^^^^^^^
.. automodule:: mmgen.core.evaluation
:members:
hooks
^^^^^^^^^^
.. automodule:: mmgen.core.hooks
:members:
optimizer
^^^^^^^^^^
.. automodule:: mmgen.core.optimizer
:members:
runners
^^^^^^^^^^
.. automodule:: mmgen.core.runners
:members:
scheduler
^^^^^^^^^^
.. automodule:: mmgen.core.scheduler
:members:
mmgen.datasets
--------------
datasets
^^^^^^^^^^
.. automodule:: mmgen.datasets
:members:
pipelines
^^^^^^^^^^
.. automodule:: mmgen.datasets.pipelines
:members:
mmgen.models
--------------
architectures
^^^^^^^^^^^^^
.. automodule:: mmgen.models.architectures
:members:
common
^^^^^^^^^^
.. automodule:: mmgen.models.common
:members:
gans
^^^^^^^^^^^^
.. automodule:: mmgen.models.gans
:members:
losses
^^^^^^^^^^^^
.. automodule:: mmgen.models.losses
:members:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment