Commit 57e0e891 authored by limm

add part mmgeneration code

parent 04e07f48
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- name: "MMGeneration Contributors"
title: "OpenMMLab's next-generation toolbox for generative models"
date-released: 2020-07-10
url: "https://github.com/open-mmlab/mmgeneration"
license: Apache-2.0
Copyright (c) OpenMMLab. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2018-2019 Open-MMLab.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Licenses for special operations
In this file, we list the operations that are covered by licenses other than Apache 2.0. Users should check these licenses carefully before adopting the corresponding operations in any commercial product.
| Operation | Files | License |
| :------------------: | :-----------------------------------------------------------------------------------------------------------------------------------: | :------------: |
| conv2d_gradfix | [mmgen/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/ops/conv2d_gradfix.py) | NVIDIA License |
| compute_pr_distances | [mmgen/core/evaluation/metric_utils.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/core/evaluation/metric_utils.py) | NVIDIA License |
# mmgeneration

<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png" width="400"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab website</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
&nbsp;&nbsp;&nbsp;&nbsp;
<b><font size="5">OpenMMLab platform</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div>&nbsp;</div>
</div>
[![PyPI](https://img.shields.io/pypi/v/mmgen)](https://pypi.org/project/mmgen)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmgeneration.readthedocs.io/en/latest/)
[![badge](https://github.com/open-mmlab/mmgeneration/workflows/build/badge.svg)](https://github.com/open-mmlab/mmgeneration/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmgeneration/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmgeneration)
[![license](https://img.shields.io/github/license/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmgeneration.svg)](https://github.com/open-mmlab/mmgeneration/issues)
[📘Documentation](https://mmgeneration.readthedocs.io/en/latest/) |
[🛠️Installation](https://mmgeneration.readthedocs.io/en/latest/get_started.html#installation) |
[👀Model Zoo](https://mmgeneration.readthedocs.io/en/latest/modelzoo_statistics.html) |
[🆕Update News](https://github.com/open-mmlab/mmgeneration/blob/master/docs/en/changelog.md) |
[🚀Ongoing Projects](https://github.com/open-mmlab/mmgeneration/projects) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmgeneration/issues)
English | [简体中文](README_zh-CN.md)
## What's New
MMGeneration has been merged into [MMEditing](https://github.com/open-mmlab/mmediting/tree/1.x), where new generation tasks and models are supported. We highlight the following new features:
- 🌟 Text2Image
  - [GLIDE](https://github.com/open-mmlab/mmediting/tree/1.x/projects/glide/configs/README.md)
  - [Disco-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/disco_diffusion/README.md)
  - [Stable-Diffusion](https://github.com/open-mmlab/mmediting/tree/1.x/configs/stable_diffusion/README.md)
- 🌟 3D-aware Generation
  - [EG3D](https://github.com/open-mmlab/mmediting/tree/1.x/configs/eg3d/README.md)
## Introduction
MMGeneration is a powerful toolkit for generative models, currently focused on GANs. It is based on PyTorch and [MMCV](https://github.com/open-mmlab/mmcv). The master branch works with **PyTorch 1.5+**.
<div align="center">
<img src="https://user-images.githubusercontent.com/12726765/114534478-9a65a900-9c81-11eb-8087-de8b6816eed8.png" width="800"/>
</div>
## Major Features
- **High-quality Training Performance:** We currently support training of unconditional GANs, internal GANs, and image translation models; conditional GANs are also supported (see the ModelZoo below).
- **Powerful Application Toolkit:** A rich toolkit of GAN applications is provided to users. GAN interpolation, GAN projection, and GAN manipulation are integrated into our framework. It's time to play with your GANs! ([Tutorial for applications](docs/en/tutorials/applications.md))
- **Efficient Distributed Training for Generative Models:** For the highly dynamic training in generative models, we adopt a new way to train dynamic models with `MMDDP`. ([Tutorial for DDP](docs/en/tutorials/ddp_train_gans.md))
- **New Modular Design for Flexible Combination:** A new design for complex loss modules is proposed for customizing the links between modules, which can achieve flexible combination among different modules. ([Tutorial for new modular design](docs/en/tutorials/customize_losses.md))
<table>
<thead>
<tr>
<td>
<div align="center">
<b> Training Visualization</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114509105-b6f4e780-9c67-11eb-8644-110b3cb01314.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Interpolation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114679300-9fd4f900-9d3e-11eb-8f37-c36a018c02f7.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Projector</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114524392-c11ee200-9c77-11eb-8b6d-37bc637f5626.gif" width="200"/>
</div></td>
<td>
<div align="center">
<b> GAN Manipulation</b>
<br/>
<img src="https://user-images.githubusercontent.com/12726765/114523716-20302700-9c77-11eb-804e-327ae1ca0c5b.gif" width="200"/>
</div></td>
</tr>
</thead>
</table>
## Highlight
- **Positional Encoding as Spatial Inductive Bias in GANs (CVPR2021)** has been released in `MMGeneration`. [\[Config\]](configs/positional_encoding_in_gans/README.md), [\[Project Page\]](https://nbei.github.io/gan-pos-encoding.html)
- Conditional GANs have been supported in our toolkit. More methods and pre-trained weights will come soon.
- Mixed-precision training (FP16) for StyleGAN2 has been supported. Please check [the comparison](configs/styleganv2/README.md) between different implementations.
## Changelog
v0.7.3 was released on 14/04/2023. Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
## Installation
MMGeneration depends on [PyTorch](https://pytorch.org/) and [MMCV](https://github.com/open-mmlab/mmcv).
Below are quick steps for installation.
**Step 1.**
Install PyTorch following [official instructions](https://pytorch.org/get-started/locally/), e.g.
```
pip3 install torch torchvision
```
**Step 2.**
Install MMCV with [MIM](https://github.com/open-mmlab/mim).
```
pip3 install openmim
mim install mmcv-full
```
**Step 3.**
Install MMGeneration from source.
```
git clone https://github.com/open-mmlab/mmgeneration.git
cd mmgeneration
pip3 install -e .
```
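You can check that the installation succeeded by importing the package (a minimal sanity check; the printed version depends on the release you installed).

```
python -c "import mmgen; print(mmgen.__version__)"
```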
Please refer to [get_started.md](docs/en/get_started.md) for more detailed instructions.
## Getting Started
Please see [get_started.md](docs/en/get_started.md) for the basic usage of MMGeneration. [docs/en/quick_run.md](docs/en/quick_run.md) offers full guidance for a quick run. For other details and tutorials, please go to our [documentation](https://mmgeneration.readthedocs.io/).
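As a minimal usage sketch, the demo scripts can also be driven from Python. `init_model` is used by the application scripts in this commit; `sample_unconditional_model` is assumed to be exported from `mmgen.apis` as in the unconditional demo script, and the config/checkpoint paths below are placeholders:

```python
from torchvision.utils import save_image

from mmgen.apis import init_model, sample_unconditional_model  # sampling helper assumed, see note above

# placeholder paths: use a config from `configs/` and a checkpoint you downloaded
config = 'configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k.py'
checkpoint = 'checkpoints/stylegan2_ffhq_256.pth'

model = init_model(config, checkpoint, device='cuda:0')
# generators in MMGeneration return BGR images in [-1, 1]
fakes = sample_unconditional_model(model, num_samples=4)
fakes = (fakes[:, [2, 1, 0]] + 1.) / 2.  # BGR -> RGB, rescale to [0, 1]
save_image(fakes, 'samples.png', nrow=2)
```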
## ModelZoo
The following methods have been carefully studied and supported in our framework:
<details open>
<summary>Unconditional GANs (click to collapse)</summary>
- [DCGAN](configs/dcgan/README.md) (ICLR'2016)
- [WGAN-GP](configs/wgan-gp/README.md) (NIPS'2017)
- [LSGAN](configs/lsgan/README.md) (ICCV'2017)
- [GGAN](configs/ggan/README.md) (arXiv'2017)
- [PGGAN](configs/pggan/README.md) (ICLR'2018)
- [StyleGANV1](configs/styleganv1/README.md) (CVPR'2019)
- [StyleGANV2](configs/styleganv2/README.md) (CVPR'2020)
- [StyleGANV3](configs/styleganv3/README.md) (NeurIPS'2021)
- [Positional Encoding in GANs](configs/positional_encoding_in_gans/README.md) (CVPR'2021)
</details>
<details open>
<summary>Conditional GANs (click to collapse)</summary>
- [SNGAN](configs/sngan_proj/README.md) (ICLR'2018)
- [Projection GAN](configs/sngan_proj/README.md) (ICLR'2018)
- [SAGAN](configs/sagan/README.md) (ICML'2019)
- [BIGGAN/BIGGAN-DEEP](configs/biggan/README.md) (ICLR'2019)
</details>
<details open>
<summary>Tricks for GANs (click to collapse)</summary>
- [ADA](configs/ada/README.md) (NeurIPS'2020)
</details>
<details open>
<summary>Image2Image Translation (click to collapse)</summary>
- [Pix2Pix](configs/pix2pix/README.md) (CVPR'2017)
- [CycleGAN](configs/cyclegan/README.md) (ICCV'2017)
</details>
<details open>
<summary>Internal Learning (click to collapse)</summary>
- [SinGAN](configs/singan/README.md) (ICCV'2019)
</details>
<details open>
<summary>Denoising Diffusion Probabilistic Models (click to collapse)</summary>
- [Improved DDPM](configs/improved_ddpm/README.md) (arXiv'2021)
</details>
## Related-Applications
- [MMGEN-FaceStylor](https://github.com/open-mmlab/MMGEN-FaceStylor)
## Contributing
We appreciate all contributions to improve MMGeneration. Please refer to [CONTRIBUTING.md](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) in MMCV for more details about the contributing guideline.
## Citation
If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2021mmgeneration,
title={{MMGeneration}: OpenMMLab Generative Model Toolbox and Benchmark},
author={MMGeneration Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmgeneration}},
year={2021}
}
```
## License
This project is released under the [Apache 2.0 license](LICENSE). Some operations in `MMGeneration` are under licenses other than Apache 2.0. Please check [LICENSES.md](LICENSES.md) carefully if you are using our code for commercial purposes.
## Projects in OpenMMLab
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
import argparse
import os
import sys
import mmcv
import torch
import torch.nn as nn
from mmcv import Config, DictAction
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.core.evaluation import slerp # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
from mmgen.models.architectures import BigGANDeepGenerator, BigGANGenerator # isort:skip # noqa
from mmgen.models.architectures.common import get_module_device # isort:skip # noqa
# yapf: enable
_default_embedding_name = dict(
BigGANGenerator='shared_embedding',
BigGANDeepGenerator='shared_embedding',
SNGANGenerator='NULL',
SAGANGenerator='NULL')
def parse_args():
parser = argparse.ArgumentParser(
description='Sampling from latents\' interpolation')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--use-cpu',
action='store_true',
help='whether to use cpu device for sampling')
parser.add_argument(
'--embedding-name',
type=str,
default=None,
help='name of conditional model\'s embedding layer')
parser.add_argument(
'--fix-z',
action='store_true',
help='whether to fix the noise for conditional model')
parser.add_argument(
'--fix-y',
action='store_true',
help='whether to fix the label for conditional model')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--samples-path', type=str, help='path to store images.')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='use which mode (ema/orig) in sampling.')
parser.add_argument(
'--show-mode',
choices=['group', 'sequence'],
default='sequence',
help='mode to show interpolation result.')
parser.add_argument(
'--interp-mode',
choices=['lerp', 'slerp'],
default='lerp',
help='mode to sample from endpoints\' interpolation.')
parser.add_argument(
'--endpoint', type=int, default=2, help='The number of endpoints.')
parser.add_argument(
'--batch-size',
type=int,
default=2,
help='batch size used in generator sampling.')
parser.add_argument(
'--interval',
type=int,
default=10,
help='The number of intervals between two endpoints.')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
args = parser.parse_args()
return args
@torch.no_grad()
def batch_inference(generator,
noise,
embedding=None,
num_batches=-1,
max_batch_size=16,
dict_key=None,
**kwargs):
"""Inference function to get a batch of desired data from output dictionary
of generator.
Args:
generator (nn.Module): Generator of a conditional model.
noise (Tensor | list[torch.tensor] | None): A batch of noise
Tensor.
embedding (Tensor, optional): Embedding tensor of label for
conditional models. Defaults to None.
num_batches (int, optional): The number of batches for
inference. Defaults to -1.
max_batch_size (int, optional): The maximum batch size for
inference. Defaults to 16.
dict_key (str, optional): key used to get results from output
dictionary of generator. Defaults to None.
Returns:
torch.Tensor: Tensor of output image, noise batch or label
batch.
"""
# split noise into groups
if noise is not None:
if isinstance(noise, torch.Tensor):
num_batches = noise.shape[0]
noise_group = torch.split(noise, max_batch_size, 0)
else:
num_batches = noise[0].shape[0]
noise_group = torch.split(noise[0], max_batch_size, 0)
noise_group = [[noise_tensor] for noise_tensor in noise_group]
else:
noise_group = [None] * (
num_batches // max_batch_size +
(1 if num_batches % max_batch_size > 0 else 0))
# split embedding into groups
if embedding is not None:
assert isinstance(embedding, torch.Tensor)
num_batches = embedding.shape[0]
embedding_group = torch.split(embedding, max_batch_size, 0)
else:
embedding_group = [None] * (
num_batches // max_batch_size +
(1 if num_batches % max_batch_size > 0 else 0))
# split batchsize into groups
batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
if num_batches % max_batch_size > 0:
batchsize_group += [num_batches % max_batch_size]
device = get_module_device(generator)
outputs = []
for _noise, _embedding, _num_batches in zip(noise_group, embedding_group,
batchsize_group):
if isinstance(_noise, torch.Tensor):
_noise = _noise.to(device)
if isinstance(_noise, list):
_noise = [ele.to(device) for ele in _noise]
if _embedding is not None:
_embedding = _embedding.to(device)
output = generator(
_noise, label=_embedding, num_batches=_num_batches, **kwargs)
output = output[dict_key] if dict_key else output
if isinstance(output, list):
output = output[0]
# once obtaining sampled results, we immediately put them into cpu
# to save cuda memory
outputs.append(output.to('cpu'))
outputs = torch.cat(outputs, dim=0)
return outputs
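# Note on batch_inference: chunked inference keeps peak GPU memory bounded by
# `max_batch_size` while still returning a single tensor with `num_batches`
# samples. The noise (and optional embedding) is split into groups, each group
# is run through the generator separately, and the per-group outputs are
# concatenated on CPU. For example, 64 requested samples with
# max_batch_size=16 result in four forward passes of 16 samples each.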
@torch.no_grad()
def sample_from_path(generator,
latent_a,
latent_b,
label_a,
label_b,
intervals,
embedding_name=None,
interp_mode='lerp',
**kwargs):
interp_alphas = torch.linspace(0, 1, intervals)
interp_samples = []
device = get_module_device(generator)
if embedding_name is None:
generator_name = generator.__class__.__name__
assert generator_name in _default_embedding_name
embedding_name = _default_embedding_name[generator_name]
embedding_fn = getattr(generator, embedding_name, nn.Identity())
embedding_a = embedding_fn(label_a.to(device))
embedding_b = embedding_fn(label_b.to(device))
for alpha in interp_alphas:
# calculate latent interpolation
if interp_mode == 'lerp':
latent_interp = torch.lerp(latent_a, latent_b, alpha)
else:
assert latent_a.ndim == latent_b.ndim == 2
latent_interp = slerp(latent_a, latent_b, alpha)
# calculate embedding interpolation
embedding_interp = embedding_a + (
embedding_b - embedding_a) * alpha.to(embedding_a.dtype)
if isinstance(generator, (BigGANDeepGenerator, BigGANGenerator)):
kwargs.update(dict(use_outside_embedding=True))
sample = batch_inference(generator, latent_interp, embedding_interp,
**kwargs)
interp_samples.append(sample)
return interp_samples
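# sample_from_path draws `intervals` evenly spaced interpolation points between
# each pair of endpoints. For latents, `lerp` computes
# latent_a + alpha * (latent_b - latent_a), while `slerp` interpolates along
# the great circle between the two latents, which is often preferable for
# Gaussian latent spaces. Labels are interpolated in the generator's embedding
# space rather than in the discrete label space.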
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
print('set random seed to', args.seed)
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
if args.sample_model == 'ema':
generator = model.generator_ema
else:
generator = model.generator
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')
generator.eval()
if not args.use_cpu:
generator = generator.cuda()
if args.show_mode == 'sequence':
assert args.endpoint >= 2
else:
assert args.endpoint >= 2 and args.endpoint % 2 == 0
kwargs = dict(max_batch_size=args.batch_size)
if args.sample_cfg is None:
args.sample_cfg = dict()
kwargs.update(args.sample_cfg)
# get noises corresponding to each endpoint
noise_batch = batch_inference(
generator,
None,
num_batches=args.endpoint,
dict_key='noise_batch',
return_noise=True,
**kwargs)
# get labels corresponding to each endpoint
label_batch = batch_inference(
generator,
None,
num_batches=args.endpoint,
dict_key='label',
return_noise=True,
**kwargs)
# set label fixed
if args.fix_y:
label_batch = label_batch[0] * torch.ones_like(label_batch)
# set noise fixed
if args.fix_z:
noise_batch = torch.cat(
[noise_batch[0:1, ]] * noise_batch.shape[0], dim=0)
if args.show_mode == 'sequence':
results = sample_from_path(generator, noise_batch[:-1, ],
noise_batch[1:, ], label_batch[:-1, ],
label_batch[1:, ], args.interval,
args.embedding_name, args.interp_mode,
**kwargs)
else:
results = sample_from_path(generator, noise_batch[::2, ],
noise_batch[1::2, ], label_batch[::2, ],
label_batch[1::2, ], args.interval,
args.embedding_name, args.interp_mode,
**kwargs)
# reorder results
results = torch.stack(results).permute(1, 0, 2, 3, 4)
_, _, ch, h, w = results.shape
results = results.reshape(-1, ch, h, w)
# rescale value range to [0, 1]
results = ((results + 1) / 2)
results = results[:, [2, 1, 0], ...]
results = results.clamp_(0, 1)
# save image
mmcv.mkdir_or_exist(args.samples_path)
if args.show_mode == 'sequence':
for i in range(results.shape[0]):
image = results[i:i + 1]
save_image(
image,
os.path.join(args.samples_path, '{:0>5d}'.format(i) + '.png'))
else:
save_image(
results,
os.path.join(args.samples_path, 'group.png'),
nrow=args.interval)
if __name__ == '__main__':
main()
import argparse
import os
import sys
import imageio
import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.core.evaluation import slerp # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
# yapf: enable
def parse_args():
parser = argparse.ArgumentParser(
description='Sampling from latents\' interpolation')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--use-cpu',
action='store_true',
help='whether to use cpu device for sampling')
parser.add_argument(
'--export-video',
action='store_true',
help='If true, export video rather than images')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--samples-path', type=str, help='path to store images.')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='use which mode (ema/orig) in sampling.')
parser.add_argument(
'--show-mode',
choices=['group', 'sequence'],
default='sequence',
help='mode to show interpolation result.')
parser.add_argument(
'--interp-mode',
choices=['lerp', 'slerp'],
default='lerp',
help='mode to sample from endpoints\' interpolation.')
parser.add_argument(
'--proj-latent',
type=str,
default=None,
help='Projection image files produced by stylegan_projector.py. If this \
argument is given, then the projected latent will be used as the input\
noise.')
parser.add_argument(
'--endpoint', type=int, default=2, help='The number of endpoints.')
parser.add_argument(
'--batch-size',
type=int,
default=2,
help='batch size used in generator sampling.')
parser.add_argument(
'--interval',
type=int,
default=10,
help='The number of intervals between two endpoints.')
parser.add_argument(
'--space',
choices=['z', 'w'],
default='w',
help='Interpolation space.')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
args = parser.parse_args()
return args
@torch.no_grad()
def batch_inference(generator,
noise,
num_batches=-1,
max_batch_size=16,
dict_key=None,
**kwargs):
"""Inference function to get a batch of desired data from output dictionary
of generator.
Args:
generator (nn.Module): Generator of a conditional model.
noise (Tensor | list[torch.tensor] | None): A batch of noise
Tensor.
num_batches (int, optional): The number of batches for
inference. Defaults to -1.
max_batch_size (int, optional): The maximum batch size for
inference. Defaults to 16.
dict_key (str, optional): key used to get results from output
dictionary of generator. Defaults to None.
Returns:
torch.Tensor: Tensor of output image, noise batch or label
batch.
"""
# split noise into groups
if noise is not None:
if isinstance(noise, torch.Tensor):
num_batches = noise.shape[0]
noise_group = torch.split(noise, max_batch_size, 0)
else:
num_batches = noise[0].shape[0]
noise_group = torch.split(noise[0], max_batch_size, 0)
noise_group = [[noise_tensor] for noise_tensor in noise_group]
else:
noise_group = [None] * (
num_batches // max_batch_size +
(1 if num_batches % max_batch_size > 0 else 0))
# split batchsize into groups
batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
if num_batches % max_batch_size > 0:
batchsize_group += [num_batches % max_batch_size]
outputs = []
for _noise, _num_batches in zip(noise_group, batchsize_group):
if isinstance(_noise, torch.Tensor):
_noise = _noise.cuda()
if isinstance(_noise, list):
_noise = [ele.cuda() for ele in _noise]
output = generator(_noise, num_batches=_num_batches, **kwargs)
output = output[dict_key] if dict_key else output
if isinstance(output, list):
output = output[0]
# once obtaining sampled results, we immediately put them into cpu
# to save cuda memory
outputs.append(output.to('cpu'))
outputs = torch.cat(outputs, dim=0)
return outputs
def layout_grid(video_out,
all_img,
grid_w=1,
grid_h=1,
float_to_uint8=True,
chw_to_hwc=True,
to_numpy=True):
r"""Arrange images into video frames.
Ref: https://github.com/NVlabs/stylegan3/blob/a5a69f58294509598714d1e88c9646c3d7c6ec94/gen_video.py#L28 # noqa
Args:
video_out (Writer): Video writer.
all_img (torch.Tensor): All images to be displayed in video.
grid_w (int, optional): Column number in a frame. Defaults to 1.
grid_h (int, optional): Row number in a frame. Defaults to 1.
float_to_uint8 (bool, optional): Change torch value from `float` to `uint8`. Defaults to True.
chw_to_hwc (bool, optional): Change channel order from `chw` to `hwc`. Defaults to True.
to_numpy (bool, optional): Change image format from `torch.Tensor` to `np.array`. Defaults to True.
Returns:
Writer: Video writer.
"""
batch_size, channels, img_h, img_w = all_img.shape
assert batch_size % (grid_w * grid_h) == 0
images_per_frame = grid_w * grid_h
n_frames = batch_size // images_per_frame
all_img = all_img.reshape(images_per_frame, n_frames, channels, img_h,
img_w).permute(1, 0, 2, 3, 4).reshape(
n_frames, images_per_frame, channels, img_h,
img_w)
for i in range(0, n_frames):
img = all_img[i]
if float_to_uint8:
img = (img * 255.).clamp(0, 255).to(torch.uint8)
img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
img = img.permute(2, 0, 3, 1, 4)
img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
if chw_to_hwc:
img = img.permute(1, 2, 0)
if to_numpy:
img = img.cpu().numpy()
video_out.append_data(img)
return video_out
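# Shape example for layout_grid: with `all_img` of shape (6, 3, H, W),
# grid_w=2 and grid_h=3, a single uint8 frame of shape (3*H, 2*W, 3) is
# appended to the writer; with the default 1x1 grid, every image becomes its
# own frame.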
def crack_integer(integer):
"""Cracking an integer into the product of two nearest integers.
Args:
integer (int): A positive integer.
Returns:
tuple: Two integers.
"""
start = int(np.sqrt(integer))
factor = integer / start
while int(factor) != factor:
start += 1
factor = integer / start
return int(factor), start
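# e.g. crack_integer(12) returns (4, 3) and crack_integer(9) returns (3, 3);
# the two factors are used below as (grid_w, grid_h) so that the paired
# interpolation results form a roughly square grid in the exported video.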
@torch.no_grad()
def sample_from_path(generator,
latent_a,
latent_b,
intervals,
interp_mode='lerp',
space='z',
**kwargs):
interp_alphas = np.linspace(0, 1, intervals)
interp_samples = []
for alpha in interp_alphas:
if interp_mode == 'lerp':
latent_interp = torch.lerp(latent_a, latent_b, alpha)
else:
assert latent_a.ndim == latent_b.ndim == 2
latent_interp = slerp(latent_a, latent_b, alpha)
if space == 'w':
latent_interp = [latent_interp]
sample = batch_inference(generator, latent_interp, **kwargs)
interp_samples.append(sample)
return interp_samples
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
print('set random seed to', args.seed)
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
if args.sample_model == 'ema':
generator = model.generator_ema
else:
generator = model.generator
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
mmcv.print_log(f'Show mode: {args.show_mode}', 'mmgen')
mmcv.print_log(f'Samples path: {args.samples_path}', 'mmgen')
generator.eval()
if not args.use_cpu:
generator = generator.cuda()
# if given proj_latent, reset args.endpoint
if args.proj_latent is not None:
mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
proj_file = torch.load(args.proj_latent)
proj_n = len(proj_file)
setattr(args, 'endpoint', proj_n)
assert args.space == 'w', 'Projected latents are w or w-plus latents.'
noise_batch = []
for img_path in proj_file:
noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
noise_batch = torch.cat(noise_batch, dim=0).cuda()
if args.use_cpu:
noise_batch = noise_batch.to('cpu')
if args.show_mode == 'sequence':
assert args.endpoint >= 2
else:
assert args.endpoint >= 2 and args.endpoint % 2 == 0,\
'''We need paired images in group mode,
so keep endpoint an even number'''
kwargs = dict(max_batch_size=args.batch_size)
if args.sample_cfg is None:
args.sample_cfg = dict()
kwargs.update(args.sample_cfg)
# remind users to fix the injected noise
if kwargs.get('randomize_noise', True):
mmcv.print_log(
'''Hint: For Style-Based GAN, you can add
`--sample-cfg randomize_noise=False` to fix injected noises''',
'mmgen')
# get noises corresponding to each endpoint
if not args.proj_latent:
noise_batch = batch_inference(
generator,
None,
num_batches=args.endpoint,
dict_key='noise_batch' if args.space == 'z' else 'latent',
return_noise=True,
**kwargs)
if args.space == 'w':
kwargs['truncation_latent'] = generator.get_mean_latent()
kwargs['input_is_latent'] = True
if args.show_mode == 'sequence':
results = sample_from_path(generator, noise_batch[:-1, ],
noise_batch[1:, ], args.interval,
args.interp_mode, args.space, **kwargs)
else:
results = sample_from_path(generator, noise_batch[::2, ],
noise_batch[1::2, ], args.interval,
args.interp_mode, args.space, **kwargs)
# reorder results
results = torch.stack(results).permute(1, 0, 2, 3, 4)
_, _, ch, h, w = results.shape
results = results.reshape(-1, ch, h, w)
# rescale value range to [0, 1]
results = ((results + 1) / 2)
results = results[:, [2, 1, 0], ...]
results = results.clamp_(0, 1)
# save image
mmcv.mkdir_or_exist(args.samples_path)
if args.show_mode == 'sequence':
if args.export_video:
# render video.
video_out = imageio.get_writer(
os.path.join(args.samples_path, 'lerp.mp4'),
mode='I',
fps=60,
codec='libx264',
bitrate='12M')
video_out = layout_grid(video_out, results)
video_out.close()
else:
for i in range(results.shape[0]):
image = results[i:i + 1]
save_image(
image,
os.path.join(args.samples_path,
'{:0>5d}'.format(i) + '.png'))
else:
if args.export_video:
# render video.
video_out = imageio.get_writer(
os.path.join(args.samples_path, 'lerp.mp4'),
mode='I',
fps=60,
codec='libx264',
bitrate='12M')
n_pair = args.endpoint // 2
grid_w, grid_h = crack_integer(n_pair)
video_out = layout_grid(
video_out, results, grid_h=grid_h, grid_w=grid_w)
video_out.close()
else:
save_image(
results,
os.path.join(args.samples_path, 'group.png'),
nrow=args.interval)
if __name__ == '__main__':
main()
"""Modified SeFa (closed-form factorization)
This GAN editing method is modified from SeFa. More details can be
found in Positional Encoding as Spatial Inductive Bias in GANs, CVPR 2021.
The major modifications are:
- Calculate eigenvectors on the matrix containing all style modulation
weights in the style convolutions;
- Allow adopting asymmetric degrees to be more robust to different samples.
"""
import argparse
import os
import sys
import mmcv
import numpy as np
import torch
from mmcv import DictAction
from mmcv.runner import load_checkpoint
from torchvision import utils
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
# yapf: enable
def calc_eigens(args, state_dict):
# get all of the style modulation weights except for weights in `to_rgb`
modulated = {
k: v
for k, v in state_dict.items()
if 'style_modulation' in k and 'to_rgb' not in k and 'weight' in k
}
weight_mat = []
for _, v in modulated.items():
weight_mat.append(v)
W = torch.cat(weight_mat, dim=0)
eigen_vector = torch.svd(W).V
# save eigen vector
output_path = os.path.splitext(args.ckpt)[0] + '_eigen-vec-mod.pth'
torch.save({'ckpt': args.ckpt, 'eigen_vector': eigen_vector}, output_path)
return eigen_vector
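# The columns of `eigen_vector` are the right singular vectors of the stacked
# modulation weight matrix; column `i` is used below as an editing direction
# via `fac * eigen_vector[:, i]`. The vectors are cached next to the
# checkpoint so they only need to be computed once per model.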
if __name__ == '__main__':
# set device
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
# set grad enabled = False
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser(
description='Apply modified closed form factorization')
# sefa args
parser.add_argument(
'-i', '--index', type=int, default=0, help='index of eigenvector')
parser.add_argument(
'-d',
'--degree',
type=float,
nargs='+',
default=[2.],
help='scalar factors for moving latent vectors along eigenvector',
)
parser.add_argument(
'--degree-step',
type=float,
default=0.25,
help='The step of changing degrees')
parser.add_argument('-l', '--layer-num', nargs='+', type=int, default=None)
parser.add_argument(
'--eigen-vector',
type=str,
default=None,
help='Path to the eigen vectors')
# gan args
parser.add_argument(
'--randomize-noise',
action='store_true',
help='whether to use random noise in the middle layers')
parser.add_argument('--ckpt', type=str, help='Path to the checkpoint')
parser.add_argument('--config', type=str, help='Path to model config')
parser.add_argument('--truncation', type=float, default=1)
parser.add_argument('--truncation-mean', type=int, default=4096)
parser.add_argument('--noise-channels', type=int, default=512)
parser.add_argument('--input-scale', type=int, default=4)
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
# system args
parser.add_argument('--num-samples', type=int, default=2)
parser.add_argument('--sample-path', type=str, default=None)
parser.add_argument('--random-seed', type=int, default=2020)
args = parser.parse_args()
set_random_seed(args.random_seed)
cfg = mmcv.Config.fromfile(args.config)
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
mmcv.print_log('Building models and loading checkpoints', 'mmgen')
# build model
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()
load_checkpoint(model, args.ckpt, map_location='cpu')
# get generator
if model.use_ema:
generator = model.generator_ema
else:
generator = model.generator
generator = generator.to(device)
generator.eval()
mmcv.print_log('Calculating or loading eigen vectors', 'mmgen')
# load/calculate eigen vector for current weights
if args.eigen_vector is None:
eigen_vector = calc_eigens(args, generator.state_dict())
else:
eigen_vector = torch.load(args.eigen_vector)['eigen_vector']
eigen_vector = eigen_vector.to(device)
if args.truncation < 1:
# TODO: get mean latent
mean_latent = generator.get_mean_latent(args.truncation_mean)
else:
mean_latent = None
noise = torch.randn((args.num_samples, args.noise_channels), device=device)
latent = generator.style_mapping(noise)
# kwargs for different gan models
kwargs = dict()
# mspie-stylegan2
if args.input_scale > 0:
kwargs['chosen_scale'] = args.input_scale
if args.sample_cfg is None:
args.sample_cfg = dict()
mmcv.print_log('Sampling images with modified SeFa', 'mmgen')
sample = generator([latent], input_is_latent=True, **args.sample_cfg)
# the first row in the saved grid shows the original samples
img_list = [sample]
if len(args.degree) == 1:
factor_list = np.arange(-args.degree[0], args.degree[0] + 0.001,
args.degree_step)
else:
factor_list = np.arange(args.degree[0], args.degree[1] + 0.001,
args.degree_step)
for fac in factor_list:
direction = fac * eigen_vector[:, args.index].unsqueeze(0)
if args.layer_num is None:
latent_input = [latent + direction]
else:
latent_all = latent.unsqueeze(1).repeat(1, generator.num_latents,
1)
for l_num in args.layer_num:
latent_all[:, l_num] = latent + direction
latent_input = [latent_all]
sample = generator(
latent_input, input_is_latent=True, **args.sample_cfg)
img_list.append(sample)
mmcv.mkdir_or_exist(args.sample_path)
if args.layer_num is None:
filename = (
f'{args.sample_path}/entangle-i{args.index}-d{args.degree}'
f'-t{args.degree_step}_{str(args.random_seed).zfill(6)}.png')
else:
filename = (f'{args.sample_path}/entangle-i{args.index}-d{args.degree}'
f'-t{args.degree_step}-l{args.layer_num}'
f'_{str(args.random_seed).zfill(6)}.png')
img = torch.cat(img_list, dim=0)[:, [2, 1, 0]]
utils.save_image(
img,
filename,
nrow=args.num_samples,
padding=0,
normalize=True,
range=(-1, 1))
mmcv.print_log(f'Save images to {filename}', 'mmgen')
import argparse
import math
import os
try:
import clip
except ImportError:
raise ImportError('To use StyleCLIP, the OpenAI CLIP package needs to be installed first')
import mmcv
import torch
import torchvision
from mmcv import Config, DictAction
from torch import optim
from tqdm import tqdm
from mmgen.apis import init_model
from mmgen.models.losses import CLIPLoss, FaceIdLoss
from mmgen.apis import set_random_seed # isort:skip # noqa
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
lr_ramp = lr_ramp * min(1, t / rampup)
return initial_lr * lr_ramp
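# get_lr implements a linear warmup followed by a cosine ramp-down, both
# expressed as fractions of the total steps: with initial_lr=0.1,
# get_lr(0.0, 0.1) == 0.0, get_lr(0.5, 0.1) == 0.1, and the rate decays
# smoothly towards 0 over the final `rampdown` fraction of the run.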
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config', help='model config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--use-cpu',
action='store_true',
help='whether to use cpu device for sampling')
parser.add_argument(
'--description',
type=str,
default='a person with purple hair',
help='the text that guides the editing/generation')
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument(
'--mode',
type=str,
default='generate',
choices=['edit', 'generate'],
help='choose between editing an image and generating a free one')
parser.add_argument(
'--l2-lambda',
type=float,
default=0.008,
help='weight of the latent distance, used for editing only')
parser.add_argument(
'--id-lambda',
type=float,
default=0.000,
help='weight of id loss, used for editing only')
parser.add_argument(
'--proj-latent',
type=str,
default=None,
help='Projection image files produced by stylegan_projector.py. If this \
argument is given, then the projected latent will be used as the init\
latent.')
parser.add_argument(
'--truncation',
type=float,
default=0.7,
help='used only for the initial latent vector, and only when a latent '
'code path is not provided')
parser.add_argument(
'--step', type=int, default=2000, help='Optimization iterations')
parser.add_argument(
'--save-interval',
type=int,
default=20,
help='if > 0 then saves intermediate results during the optimization')
parser.add_argument(
'--results-dir', type=str, default='work_dirs/styleclip/')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
args = parser.parse_args()
return args
def main():
args = parse_args()
# set cudnn_benchmark
cfg = Config.fromfile(args.config)
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
print('set random seed to', args.seed)
set_random_seed(args.seed, deterministic=args.deterministic)
os.makedirs(args.results_dir, exist_ok=True)
text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
model = init_model(args.config, args.checkpoint, device='cpu')
g_ema = model.generator_ema
g_ema.eval()
if not args.use_cpu:
g_ema = g_ema.cuda()
mean_latent = g_ema.get_mean_latent()
# if given proj_latent
if args.proj_latent is not None:
mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
proj_file = torch.load(args.proj_latent)
proj_n = len(proj_file)
assert proj_n == 1
noise_batch = []
for img_path in proj_file:
noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
latent_code_init = torch.cat(noise_batch, dim=0).cuda()
elif args.mode == 'edit':
latent_code_init_not_trunc = torch.randn(1, 512).cuda()
with torch.no_grad():
results = g_ema([latent_code_init_not_trunc],
return_latents=True,
truncation=args.truncation,
truncation_latent=mean_latent)
latent_code_init = results['latent']
else:
latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)
with torch.no_grad():
img_orig = g_ema([latent_code_init],
input_is_latent=True,
randomize_noise=False)
latent = latent_code_init.detach().clone()
latent.requires_grad = True
clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
id_loss = FaceIdLoss(
facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda'))
optimizer = optim.Adam([latent], lr=args.lr)
pbar = tqdm(range(args.step))
mmcv.print_log(f'Description: {args.description}')
for i in pbar:
t = i / args.step
lr = get_lr(t, args.lr)
optimizer.param_groups[0]['lr'] = lr
img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)
img_gen = img_gen[:, [2, 1, 0], ...]
# clip loss
c_loss = clip_loss(image=img_gen, text=text_inputs)
if args.id_lambda > 0:
i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
else:
i_loss = 0
if args.mode == 'edit':
l2_loss = ((latent_code_init - latent)**2).sum()
loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
else:
loss = c_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description((f'loss: {loss.item():.4f};'))
if args.save_interval > 0 and (i % args.save_interval == 0):
with torch.no_grad():
img_gen = g_ema([latent],
input_is_latent=True,
randomize_noise=False)
img_gen = img_gen[:, [2, 1, 0], ...]
torchvision.utils.save_image(
img_gen,
os.path.join(args.results_dir, f'{str(i).zfill(5)}.png'),
normalize=True,
range=(-1, 1))
if args.mode == 'edit':
img_orig = img_orig[:, [2, 1, 0], ...]
final_result = torch.cat([img_orig, img_gen])
else:
final_result = img_gen
torchvision.utils.save_image(
final_result.detach().cpu(),
os.path.join(args.results_dir, 'final_result.png'),
normalize=True,
scale_each=True,
range=(-1, 1))
if __name__ == '__main__':
main()
r"""
This app is used to invert the synthesis network of the StyleGAN series. We
find the matching latent vector w for given images so that we can manipulate
the images in the latent feature space.
Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/projector.py # noqa
"""
import argparse
import os
import sys
from collections import OrderedDict
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmcv import Config
from mmcv.runner import load_checkpoint
from PIL import Image
from torch import optim
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
from mmgen.models.architectures.lpips import PerceptualLoss # isort:skip # noqa
# yapf: enable
def parse_args():
parser = argparse.ArgumentParser(
description='Image projector to the StyleGAN-based generator latent \
spaces')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'files',
metavar='FILES',
nargs='+',
help='path to image files to be projected')
parser.add_argument(
'--results-path', type=str, help='path to store projection results.')
parser.add_argument(
'--use-cpu',
action='store_true',
help='whether to use cpu device for sampling')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='use which mode (ema/orig) in sampling.')
parser.add_argument(
'--lr-rampup',
type=float,
default=0.05,
help='proportion of the learning rate warmup iters in the total iters')
parser.add_argument(
'--lr-rampdown',
type=float,
default=0.25,
help='proportion of the learning rate decay iters in the total iters')
parser.add_argument(
'--lr', type=float, default=0.1, help='maximum learning rate')
parser.add_argument(
'--noise',
type=float,
default=0.05,
help='strength of the noise level')
parser.add_argument(
'--noise-ramp',
type=float,
default=0.75,
help='proportion of the noise level decay iters in the total iters',
)
parser.add_argument(
'--total-iters', type=int, default=1000, help='optimize iterations')
parser.add_argument(
'--noise-regularize',
type=float,
default=1e5,
help='weight of the noise regularization',
)
parser.add_argument(
'--mse', type=float, default=0, help='weight of the mse loss')
parser.add_argument(
'--n-mean-latent',
type=int,
default=10000,
help='sampling times to obtain the mean latent')
parser.add_argument(
'--w-plus',
action='store_true',
help='allow to use distinct latent codes to each layers',
)
args = parser.parse_args()
return args
def noise_regularize(noises):
loss = 0
for noise in noises:
size = noise.shape[2]
while True:
loss = (
loss +
(noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) +
(noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2))
if size <= 8:
break
noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
noise = noise.mean([3, 5])
size //= 2
return loss
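# noise_regularize penalizes spatial self-correlation in the injected noise
# maps: at each pyramid level it accumulates the squared mean of the product
# between the noise and its one-pixel roll along each spatial axis, then
# average-pools the map 2x and repeats until the resolution drops to 8. This
# keeps the optimized noise close to i.i.d. Gaussian instead of letting it
# absorb image content.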
def noise_normalize_(noises):
for noise in noises:
mean = noise.mean()
std = noise.std()
noise.data.add_(-mean).div_(std)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
lr_ramp = lr_ramp * min(1, t / rampup)
return initial_lr * lr_ramp
def latent_noise(latent, strength):
noise = torch.randn_like(latent) * strength
return latent + noise
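# latent_noise perturbs the current latent estimate with Gaussian noise whose
# strength (set in the main loop) decays to zero over the first `noise_ramp`
# fraction of the optimization, which encourages exploration early on.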
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
print('set random seed to', args.seed)
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
if args.sample_model == 'ema':
generator = model.generator_ema
else:
generator = model.generator
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
generator.eval()
device = 'cpu'
if not args.use_cpu:
generator = generator.cuda()
device = 'cuda'
img_size = min(generator.out_size, 256)
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# read images
imgs = []
for imgfile in args.files:
img = Image.open(imgfile).convert('RGB')
img = transform(img)
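        # flip RGB (from PIL) to BGR, the channel order used inside mmgen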
img = img[[2, 1, 0], ...]
imgs.append(img)
imgs = torch.stack(imgs, 0).to(device)
# get mean and standard deviation of style latents
with torch.no_grad():
noise_sample = torch.randn(
args.n_mean_latent, generator.style_channels, device=device)
latent_out = generator.style_mapping(noise_sample)
latent_mean = latent_out.mean(0)
latent_std = ((latent_out - latent_mean).pow(2).sum() /
args.n_mean_latent)**0.5
latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(
imgs.shape[0], 1)
if args.w_plus:
latent_in = latent_in.unsqueeze(1).repeat(1, generator.num_latents, 1)
latent_in.requires_grad = True
# define lpips loss
percept = PerceptualLoss(use_gpu=device.startswith('cuda'))
# initialize layer noises
noises_single = generator.make_injected_noise()
noises = []
for noise in noises_single:
noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
for noise in noises:
noise.requires_grad = True
optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
pbar = tqdm(range(args.total_iters))
# run optimization
for i in pbar:
t = i / args.total_iters
lr = get_lr(t, args.lr, args.lr_rampdown, args.lr_rampup)
optimizer.param_groups[0]['lr'] = lr
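        # the additive latent noise decays quadratically and vanishes after
        # the first `noise_ramp` fraction of the iterations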
noise_strength = latent_std * args.noise * max(
0, 1 - t / args.noise_ramp)**2
latent_n = latent_noise(latent_in, noise_strength.item())
img_gen = generator([latent_n],
input_is_latent=True,
injected_noise=noises)
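        # downsample high-resolution outputs by average pooling so the
        # perceptual loss is computed at 256x256 at most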
batch, channel, height, width = img_gen.shape
if height > 256:
factor = height // 256
img_gen = img_gen.reshape(batch, channel, height // factor, factor,
width // factor, factor)
img_gen = img_gen.mean([3, 5])
p_loss = percept(img_gen, imgs).sum()
n_loss = noise_regularize(noises)
mse_loss = F.mse_loss(img_gen, imgs)
loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
noise_normalize_(noises)
pbar.set_description(
f' perceptual: {p_loss.item():.4f}, noise regularize:'
f'{n_loss.item():.4f}, mse: {mse_loss.item():.4f}, lr: {lr:.4f}')
results = generator([latent_in.detach().clone()],
input_is_latent=True,
injected_noise=noises)
# rescale value range to [0, 1]
results = ((results + 1) / 2)
results = results[:, [2, 1, 0], ...]
results = results.clamp_(0, 1)
mmcv.mkdir_or_exist(args.results_path)
# save projection results
result_file = OrderedDict()
for i, input_name in enumerate(args.files):
noise_single = []
for noise in noises:
noise_single.append(noise[i:i + 1])
result_file[input_name] = {
'img': img_gen[i],
'latent': latent_in[i],
'injected_noise': noise_single,
}
img_name = os.path.splitext(
os.path.basename(input_name))[0] + '-project.png'
save_image(results[i], os.path.join(args.results_path, img_name))
torch.save(result_file, os.path.join(args.results_path,
'project_result.pt'))
if __name__ == '__main__':
main()
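The learning-rate schedule above is easy to sanity-check in isolation. The short, self-contained sketch below re-implements the same warmup/cosine-decay formula as `get_lr` with the script's default arguments; the concrete numbers are only illustrative and the snippet is not part of the projector script itself.

import numpy as np

def lr_schedule(t, initial_lr=0.1, rampdown=0.25, rampup=0.05):
    # same formula as get_lr above
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp

lrs = [lr_schedule(i / 1000) for i in range(1000)]
# starts at 0 (warmup), plateaus at the maximum lr, then decays towards 0
print(lrs[0], max(lrs), lrs[-1])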
dataset_type = 'UnconditionalImageDataset'
# Note that the `Resize` operation with the `pillow` backend and `bicubic`
# interpolation is required for correct IS evaluation.
val_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(
type='Resize',
keys=['real_img'],
scale=(299, 299),
backend='pillow',
interpolation='bicubic'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
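As a rough illustration of the comment above, the `Resize` step with the `pillow` backend is expected to correspond to a plain PIL bicubic resize to 299x299, the input size of the InceptionV3 network used for IS; in the sketch below the blank image is only a stand-in for a real sample, and the snippet is not part of the config.

from PIL import Image

img = Image.new('RGB', (640, 480))               # stand-in for a real image
img_299 = img.resize((299, 299), Image.BICUBIC)  # bicubic, pillow backend
assert img_299.size == (299, 299)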
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the default
# order in CIFAR-10 is RGB.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
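As a quick numeric check of the `img_norm_cfg` above (not part of the config): with mean and std both 127.5, the `Normalize` step maps uint8 pixel values from [0, 255] into roughly [-1, 1], the value range the generators in this repo operate on.

import numpy as np

pixels = np.array([0, 127.5, 255], dtype=np.float32)
print((pixels - 127.5) / 127.5)   # [-1.  0.  1.]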
dataset_type = 'mmcls.CIFAR10'
# This config is used to extract the inception state of the CIFAR dataset.
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification.
# The default order in CIFAR-10 is RGB, so we set `to_rgb` to `False`.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the default
# order in CIFAR-10 is RGB.
# CIFAR dataset without augmentations: the `RandomFlip` and `RandomCrop`
# augmentations are removed.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=500,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=train_pipeline)),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the default
# order in CIFAR-10 is RGB.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# CIFAR dataset without the standard crop/flip augmentations.
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification. Importantly,
# `to_rgb` is set to `True` to convert images to BGR order, since the default
# order in CIFAR-10 is RGB.
# Follow the pipeline in
# https://github.com/pfnet-research/sngan_projection/blob/master/datasets/cifar10.py
# Only the `RandomImgNoise` augmentation is adopted.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='RandomImgNoise', keys=['img']),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'mmcls.CIFAR10'
# Different from mmcls, we adopt the setting used in BigGAN.
# Note that the pipelines below are from MMClassification.
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
# Different from the classification task, the val/test splits also use the
# training set, which is the same setting as StyleGAN-ADA.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
dataset_type = 'UnconditionalImageDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
val_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` and `imgs_root` need to be set.
data = dict(
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=100,
dataset=dict(
type=dataset_type, imgs_root=None, pipeline=train_pipeline)),
val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
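For reference, `RepeatDataset` only changes the effective epoch length: the wrapped dataset is iterated `times` times per epoch. The numbers in the sketch below are purely illustrative.

times, num_real_images = 100, 3000   # illustrative values
effective_length = times * num_real_images
print(effective_length)              # 300000 samples per epoch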
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Resize', keys=['real_img'], scale=(128, 128)),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` needs to be set; the `imgs_roots` below is only an example.
data = dict(
    # `samples_per_gpu` should match the value for the first scale,
    # e.g. '4': 64 in this case
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type,
# just an example
imgs_roots={'128': './data/lsun/bedroom_train'},
pipeline=train_pipeline,
gpu_samples_base=4,
        # note that this should be adjusted according to the total number
        # of GPUs
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4
},
len_per_stage=-1))
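A small sketch of how the per-scale batch size above is presumably resolved during progressive growing (assumption: scales that are not listed fall back to `gpu_samples_base`); the helper below is only illustrative and not an mmgen API.

gpu_samples_per_scale = {'4': 64, '8': 32, '16': 16, '32': 8, '64': 4}
gpu_samples_base = 4

def samples_per_gpu_at(scale):
    # illustrative helper: look up the current scale, otherwise use the base
    return gpu_samples_per_scale.get(str(scale), gpu_samples_base)

print(samples_per_gpu_at(4), samples_per_gpu_at(128))   # 64 4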
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
# `samples_per_gpu` needs to be set; the `imgs_roots` below is only an example.
data = dict(
    # `samples_per_gpu` should match the value for the first scale,
    # e.g. '4': 64 in this case
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type=dataset_type,
# just an example
imgs_roots={
'64': './data/celebahq/imgs_64',
'256': './data/celebahq/imgs_256',
'512': './data/celebahq/imgs_512',
'1024': './data/celebahq/imgs_1024'
},
pipeline=train_pipeline,
gpu_samples_base=4,
        # note that this should be adjusted according to the total number
        # of GPUs
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4
},
len_per_stage=300000))
dataset_type = 'GrowScaleImgDataset'
train_pipeline = [
dict(
type='LoadImageFromFile',
key='real_img',
io_backend='disk',
),
dict(type='Flip', keys=['real_img'], direction='horizontal'),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
to_rgb=False),
dict(type='ImageToTensor', keys=['real_img']),
dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(
type='GrowScaleImgDataset',
imgs_roots=dict({
'1024': './data/ffhq/images',
'256': './data/ffhq/ffhq_imgs/ffhq_256',
'64': './data/ffhq/ffhq_imgs/ffhq_64'
}),
pipeline=train_pipeline,
gpu_samples_base=4,
gpu_samples_per_scale={
'4': 64,
'8': 32,
'16': 16,
'32': 8,
'64': 4,
'128': 4,
'256': 4,
'512': 4,
'1024': 4
},
len_per_stage=300000))