Commit 05f0839a authored by dengjb's avatar dengjb
Browse files

update

parents
datasets/
checkpoints/
results/
build/
dist/
*.png
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# CycleGAN
## 论文
[CycleGAN](https://arxiv.org/pdf/1703.10593.pdf)
## 模型简介
CycleGAN是一种基于循环一致性的生成对抗网络,主要用于非配对数据的跨域图像转换,核心应用包括‌网络结构创新‌、‌超分辨率重建‌和‌风格迁移‌。
### CycleGAN网络结构
**‌双生成器架构**‌:包含两个生成器(G和F)及两个判别器(D_X和D_Y),通过循环一致性损失(Cycle-Consistency Loss)实现无监督训练。‌‌
**‌注意力机制**‌:部分改进版本引入残差网络和注意力模块,增强低光照图像等特殊场景的细节保留能力。‌‌
**‌性能对比**‌:在Enlighten和LOL数据集上,改进后的CycleGAN相比传统GAN模型(如UNIT)PSNR提升约3-5 dB
<img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="1000px"/>
## 环境依赖
| 软件 | 版本 |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.57.1 |
| vllm | 0.11.0+das.opt1.alpha.8e22ded.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| triton | 3.1+das.opt1.3c5d12d.dtk25041 |
| flash_attn | 2.6.1+das.opt1.dtk2504 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
当前仅支持镜像:
- 挂载地址`-v`根据实际模型情况修改
```bash
docker run -it --shm-size 60g --network=host --name cyclegan --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro -v /path/your_code_path/:/path/your_code_path/ image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04.1-py3.11 bash
```
更多镜像可前往[光源](https://sourcefind.cn/#/service-list)下载使用。
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.sourcefind.cn/tool/)开发者社区下载安装。
## 数据集
使用以下脚本下载数据集。许多数据集是由其他研究人员收集的。如果您使用这些数据,请引用他们的论文。
```bash
bash ./datasets/download_dataset.sh dataset_name
```
- `facades`:来自 [CMP Facades 数据集](http://cmp.felk.cvut.cz/~tylecr1/facade/) 的 400 张图像。[[引用](datasets/bibtex/facades.tex)]
- `cityscapes`:来自 [Cityscapes 训练集](https://www.cityscapes-dataset.com/) 的 2975 张图像。[[引用](datasets/bibtex/cityscapes.tex)]。注意:由于许可问题,我们不在我们的仓库中托管该数据集。请直接从 Cityscapes 网页下载数据集。有关详细信息,请参阅 [./datasets/prepare_cityscapes_dataset.py](file:///Users/dlyrm/work_space/model_zoo/cyclegan_pytorch/datasets/prepare_cityscapes_dataset.py)
- `maps`:从 Google Maps 抓取的 1096 张训练图像。
- `horse2zebra`:使用关键词 `wild horse``zebra`[ImageNet](http://www.image-net.org/) 下载的 939 张马的图像和 1177 张斑马的图像。
- `apple2orange`:使用关键词 `apple``navel orange`[ImageNet](http://www.image-net.org/) 下载的 996 张苹果图像和 1020 张橙子图像。
- `summer2winter_yosemite`:使用 Flickr API 下载的 1273 张夏季约塞米蒂图像和 854 张冬季约塞米蒂图像。详情请见我们的论文。
- `monet2photo``vangogh2photo``ukiyoe2photo``cezanne2photo`:艺术图像从 [Wikiart](https://www.wikiart.org/) 下载。真实照片使用 *landscape**landscapephotography* 标签组合从 Flickr 下载。每类的训练集大小为:莫奈:1074,塞尚:584,梵高:401,浮世绘:1433,照片:6853。
- `iphone2dslr_flower`:两类图像均从 Flickr 下载。每类的训练集大小为 iPhone:1813,DSLR:3316。详情请见我们的论文。
## 训练
### 应用预训练模型
- 下载测试图片(由 [Alexei Efros](https://www.flickr.com/photos/aaefros) 拍摄):
```
bash ./datasets/download_dataset.sh ae_photos
```
- 下载预训练模型 `style_cezanne`(对于 CPU 模型,使用 `style_cezanne_cpu`):
```
bash ./pretrained_models/download_model.sh style_cezanne
```
- 现在,让我们生成保罗·塞尚风格的图像:
```
DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua
```
测试结果将保存到 `./results/style_cezanne_pretrained/latest_test/index.html`
更多预训练模型请参见 [模型库](#model-zoo)
[./examples/test_vangogh_style_on_ae_photos.sh](file:///Users/dlyrm/work_space/model_zoo/cyclegan_pytorch/examples/test_vangogh_style_on_ae_photos.sh) 是一个示例脚本,它下载预训练的梵高风格网络并在 Efros 的照片上运行。
### 训练
- 下载数据集(例如来自 ImageNet 的斑马和马的图像):
```bash
bash ./datasets/download_dataset.sh horse2zebra
```
- 训练模型:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua
```
- (仅限 CPU)不使用 GPU 或 CUDNN 的相同训练命令。设置环境变量 ```gpu=0 cudnn=0``` 强制仅使用 CPU:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua
```
- (可选)启动显示服务器以在模型训练时查看结果。(详见 [显示界面](#display-ui)):
```bash
th -ldisplay.start 8000 0.0.0.0
```
## 推理
### 测试
- 最后,测试模型:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua
```
测试结果将保存到这里的HTML文件中:`./results/horse2zebra_model/latest_test/index.html`
这段命令会使用训练好的 `horse2zebra_model` 模型对 `./datasets/horse2zebra` 数据集中的测试图像进行转换测试,并将生成的结果以网页形式保存在指定路径下。
### 精度
DCU与GPU精度一致,推理框架:vllm。
## 预训练权重
下载方式
```
bash ./pretrained_models/download_model.sh model_name
```
| 模型名称 | 权重大小 | DCU型号 | 最低卡数需求 |下载地址|
|:-----:|:----------:|:----------:|:---------------------:|:----------:|
| latest_net_G | - | K100AI | 1 | - |
## 源码仓库及问题反馈
- https://developer.sourcefind.cn/codes/modelzoo/cyclegan_pytorch
## 参考资料
- https://github.com/junyanz/CycleGAN
<img src='imgs/horse2zebra.gif' align="right" width=384>
<br><br><br>
# CycleGAN
### [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | [project page](https://junyanz.github.io/CycleGAN/) | [paper](https://arxiv.org/pdf/1703.10593.pdf)
Torch implementation for learning an image-to-image translation (i.e. [pix2pix](https://github.com/phillipi/pix2pix)) **without** input-output pairs, for example:
**New**: Please check out [contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT), our new unpaired image-to-image translation model that enables fast and memory-efficient training.
<img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="1000px"/>
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://junyanz.github.io/CycleGAN/)
[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)
Berkeley AI Research Lab, UC Berkeley
In ICCV 2017. (* equal contributions)
This package includes CycleGAN, [pix2pix](https://github.com/phillipi/pix2pix), as well as other methods like [BiGAN](https://arxiv.org/abs/1605.09782)/[ALI](https://ishmaelbelghazi.github.io/ALI/) and Apple's paper [S+U learning](https://arxiv.org/pdf/1612.07828.pdf).
The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung).
**Update**: Please check out [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation for CycleGAN and pix2pix.
The PyTorch version is under active development and can produce results comparable or better than this Torch version.
## Other implementations:
<p><a href="https://github.com/leehomyc/cyclegan-1"> [Tensorflow]</a> (by Harry Yang),
<a href="https://github.com/architrathore/CycleGAN/">[Tensorflow]</a> (by Archit Rathore),
<a href="https://github.com/vanhuyz/CycleGAN-TensorFlow">[Tensorflow]</a> (by Van Huy),
<a href="https://github.com/XHUJOY/CycleGAN-tensorflow">[Tensorflow]</a> (by Xiaowei Hu),
<a href="https://github.com/LynnHo/CycleGAN-Tensorflow-Simple"> [Tensorflow-simple]</a> (by Zhenliang He),
<a href="https://github.com/luoxier/CycleGAN_Tensorlayer"> [TensorLayer]</a> (by luoxier),
<a href="https://github.com/Aixile/chainer-cyclegan">[Chainer]</a> (by Yanghua Jin),
<a href="https://github.com/yunjey/mnist-svhn-transfer">[Minimal PyTorch]</a> (by yunjey),
<a href="https://github.com/Ldpe2G/DeepLearningForFun/tree/master/Mxnet-Scala/CycleGAN">[Mxnet]</a> (by Ldpe2G),
<a href="https://github.com/tjwei/GANotebooks">[lasagne/Keras]</a> (by tjwei),
<a href="https://github.com/simontomaskarlsson/CycleGAN-Keras">[Keras]</a> (by Simon Karlsson)</p>
</ul>
## Applications
### Monet Paintings to Photos
<img src="https://junyanz.github.io/CycleGAN/images/painting2photo.jpg" width="1000px"/>
### Collection Style Transfer
<img src="https://junyanz.github.io/CycleGAN/images/photo2painting.jpg" width="1000px"/>
### Object Transfiguration
<img src="https://junyanz.github.io/CycleGAN/images/objects.jpg" width="1000px"/>
### Season Transfer
<img src="https://junyanz.github.io/CycleGAN/images/season.jpg" width="1000px"/>
### Photo Enhancement: Narrow depth of field
<img src="https://junyanz.github.io/CycleGAN/images/photo_enhancement.jpg" width="1000px"/>
## Prerequisites
- Linux or OSX
- NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested)
- For MAC users, you need the Linux/GNU commands `gfind` and `gwc`, which can be installed with `brew install findutils coreutils`.
## Getting Started
### Installation
- Install torch and dependencies from https://github.com/torch/distro
- Install torch packages `nngraph`, `class`, `display`
```bash
luarocks install nngraph
luarocks install class
luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
```
- Clone this repo:
```bash
git clone https://github.com/junyanz/CycleGAN
cd CycleGAN
```
### Apply a Pre-trained Model
- Download the test photos (taken by [Alexei Efros](https://www.flickr.com/photos/aaefros)):
```
bash ./datasets/download_dataset.sh ae_photos
```
- Download the pre-trained model `style_cezanne` (For CPU model, use `style_cezanne_cpu`):
```
bash ./pretrained_models/download_model.sh style_cezanne
```
- Now, let's generate Paul Cézanne style images:
```
DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua
```
The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`.
Please refer to [Model Zoo](#model-zoo) for more pre-trained models.
`./examples/test_vangogh_style_on_ae_photos.sh` is an example script that downloads the pretrained Van Gogh style network and runs it on Efros's photos.
### Train
- Download a dataset (e.g. zebra and horse images from ImageNet):
```bash
bash ./datasets/download_dataset.sh horse2zebra
```
- Train a model:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua
```
- (CPU only) The same training command without using a GPU or CUDNN. Setting the environment variables ```gpu=0 cudnn=0``` forces CPU only
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua
```
- (Optionally) start the display server to view results as the model trains. (See [Display UI](#display-ui) for more details):
```bash
th -ldisplay.start 8000 0.0.0.0
```
### Test
- Finally, test the model:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua
```
The test results will be saved to an HTML file here: `./results/horse2zebra_model/latest_test/index.html`.
## Model Zoo
Download the pre-trained models with the following script. The model will be saved to `./checkpoints/model_name/latest_net_G.t7`.
```bash
bash ./pretrained_models/download_model.sh model_name
```
- `orange2apple` (orange -> apple) and `apple2orange`: trained on ImageNet categories `apple` and `orange`.
- `horse2zebra` (horse -> zebra) and `zebra2horse` (zebra -> horse): trained on ImageNet categories `horse` and `zebra`.
- `style_monet` (landscape photo -> Monet painting style), `style_vangogh` (landscape photo -> Van Gogh painting style), `style_ukiyoe` (landscape photo -> Ukiyo-e painting style), `style_cezanne` (landscape photo -> Cezanne painting style): trained on paintings and Flickr landscape photos.
- `monet2photo` (Monet paintings -> real landscape): trained on paintings and Flickr landscape photographs.
- `cityscapes_photo2label` (street scene -> label) and `cityscapes_label2photo` (label -> street scene): trained on the Cityscapes dataset.
- `map2sat` (map -> aerial photo) and `sat2map` (aerial photo -> map): trained on Google maps.
- `iphone2dslr_flower` (iPhone photos of flowers -> DSLR photos of flowers): trained on Flickr photos.
CPU models can be downloaded using:
```bash
bash pretrained_models/download_model.sh <name>_cpu
```
, where `<name>` can be `horse2zebra`, `style_monet`, etc. You just need to append `_cpu` to the target model.
## Training and Test Details
To train a model,
```bash
DATA_ROOT=/path/to/data/ name=expt_name th train.lua
```
Models are saved to `./checkpoints/expt_name` (can be changed by passing `checkpoint_dir=your_dir` in train.lua).
See `opt_train` in `options.lua` for additional training options.
To test the model,
```bash
DATA_ROOT=/path/to/data/ name=expt_name phase=test th test.lua
```
This will run the model named `expt_name` in both directions on all images in `/path/to/data/testA` and `/path/to/data/testB`.
A webpage with result images will be saved to `./results/expt_name` (can be changed by passing `results_dir=your_dir` in test.lua).
See `opt_test` in `options.lua` for additional test options. Please use `model=one_direction_test` if you only would like to generate outputs of the trained network in only one direction, and specify `which_direction=AtoB` or `which_direction=BtoA` to set the direction.
There are other options that can be used. For example, you can specify `resize_or_crop=crop` option to avoid resizing the image to squares. This is indeed how we trained GTA2Cityscapes model in the projet [webpage](https://junyanz.github.io/CycleGAN/) and [Cycada](https://arxiv.org/pdf/1711.03213.pdf) model. We prepared the images at 1024px resolution, and used `resize_or_crop=crop fineSize=360` to work with the cropped images of size 360x360. We also used `lambda_identity=1.0`.
## Datasets
Download the datasets using the following script. Many of the datasets were collected by other researchers. Please cite their papers if you use the data.
```bash
bash ./datasets/download_dataset.sh dataset_name
```
- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)]
- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)]. Note: Due to license issue, we do not host the dataset on our repo. Please download the dataset directly from the Cityscapes webpage. Please refer to `./datasets/prepare_cityscapes_dataset.py` for more detail.
- `maps`: 1096 training images scraped from Google Maps.
- `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra`
- `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.
- `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using Flickr API. See more details in our paper.
- `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos are downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853.
- `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper.
## Display UI
Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display).
- Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec`
- Then start the server with: `th -ldisplay.start`
- Open this URL in your browser: [http://localhost:8000](http://localhost:8000)
By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface:
```bash
th -ldisplay.start 8000 0.0.0.0
```
Then open `http://(hostname):(port)/` in your browser to load the remote desktop.
## Setup Training and Test data
To train CycleGAN model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and B. You can test your model on your training set by setting ``phase='train'`` in `test.lua`. You can also create subdirectories `testA` and `testB` if you have test data.
You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. See the following section for more discussion.
## Failure cases
<img align="left" style="padding:10px" src="https://junyanz.github.io/CycleGAN/images/failure_putin.jpg" width=320>
Our model does not work well when the test image is rather different from the images on which the model is trained, as is the case in the figure to the left (we trained on horses and zebras without riders, but test here one a horse with a rider). See additional typical failure cases [here](https://junyanz.github.io/CycleGAN/images/failures.jpg). On translation tasks that involve color and texture changes, like many of those reported above, the method often succeeds. We have also explored tasks that require geometric changes, with little success. For example, on the task of `dog<->cat` transfiguration, the learned translation degenerates into making minimal changes to the input. We also observe a lingering gap between the results achievable with paired training data and those achieved by our unpaired method. In some cases, this gap may be very hard -- or even impossible,-- to close: for example, our method sometimes permutes the labels for tree and building in the output of the cityscapes photos->labels task.
## Citation
If you use this code for your research, please cite our [paper](https://junyanz.github.io/CycleGAN/):
```
@inproceedings{CycleGAN2017,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networkss},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
year={2017}
}
```
## Related Projects:
**[contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT)**<br>
**[pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) |
[BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**<br>
**[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)**
## Cat Paper Collection
If you love cats, and love reading cool graphics, vision, and ML papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers).
## Acknowledgments
Code borrows from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). The generative network is adopted from [neural-style](https://github.com/jcjohnson/neural-style) with [Instance Normalization](https://github.com/DmitryUlyanov/texture_nets/blob/master/InstanceNormalization.lua).
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the datasets are aligned
-- The datasets are of the same size
--------------------------------------------------------------------------------
require 'data.base_data_loader'
local class = require 'class'
data_util = paths.dofile('data_util.lua')
AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')
function AlignedDataLoader:__init(conf)
BaseDataLoader.__init(self, conf)
conf = conf or {}
end
function AlignedDataLoader:name()
return 'AlignedDataLoader'
end
function AlignedDataLoader:Initialize(opt)
opt.align_data = 1
self.idx_A = {1, opt.input_nc}
self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc}
local nc = 3--opt.input_nc + opt.output_nc
self.data = data_util.load_dataset('', opt, nc)
end
-- actually fetches the data
-- |return|: a table of two tables, each corresponding to
-- the batch for dataset A and dataset B
function AlignedDataLoader:LoadBatchForAllDatasets()
local batch_data, path = self.data:getBatch()
local batchA = batch_data[{ {}, self.idx_A, {}, {} }]
local batchB = batch_data[{ {}, self.idx_B, {}, {} }]
return batchA, batchB, path, path
end
-- returns the size of each dataset
function AlignedDataLoader:size(dataset)
return self.data:size()
end
--------------------------------------------------------------------------------
-- Base Class for Providing Data
--------------------------------------------------------------------------------
local class = require 'class'
require 'torch'
BaseDataLoader = class('BaseDataLoader')
function BaseDataLoader:__init(conf)
conf = conf or {}
self.data_tm = torch.Timer()
end
function BaseDataLoader:name()
return 'BaseDataLoader'
end
function BaseDataLoader:Initialize(opt)
end
-- actually fetches the data
-- |return|: a table of two tables, each corresponding to
-- the batch for dataset A and dataset B
function BaseDataLoader:LoadBatchForAllDatasets()
return {},{},{},{}
end
-- returns the next batch
-- a wrapper of getBatch(), which is meant to be overriden by subclasses
-- |return|: a table of two tables, each corresponding to
-- the batch for dataset A and dataset B
function BaseDataLoader:GetNextBatch()
self.data_tm:reset()
self.data_tm:resume()
local dataA, dataB, pathA, pathB = self:LoadBatchForAllDatasets()
self.data_tm:stop()
return dataA, dataB, pathA, pathB
end
function BaseDataLoader:time_elapsed_to_fetch_data()
return self.data_tm:time().real
end
-- returns the size of each dataset
function BaseDataLoader:size(dataset)
return 0
end
--[[
This data loader is a modified version of the one from dcgan.torch
(see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
]]--
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')
local data = {}
local result = {}
local unpack = unpack and unpack or table.unpack
function data.new(n, opt_)
opt_ = opt_ or {}
local self = {}
for k,v in pairs(data) do
self[k] = v
end
local donkey_file = 'donkey_folder.lua'
-- print('n..' .. n)
if n > 0 then
local options = opt_
self.threads = Threads(n,
function() require 'torch' end,
function(idx)
opt = options
tid = idx
local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
torch.manualSeed(seed)
torch.setnumthreads(1)
print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
assert(options, 'options not found')
assert(opt, 'opt not given')
print(opt)
paths.dofile(donkey_file)
end
)
else
if donkey_file then paths.dofile(donkey_file) end
-- print('empty threads')
self.threads = {}
function self.threads:addjob(f1, f2) f2(f1()) end
function self.threads:dojob() end
function self.threads:synchronize() end
end
local nSamples = 0
self.threads:addjob(function() return trainLoader:size() end,
function(c) nSamples = c end)
self.threads:synchronize()
self._size = nSamples
for i = 1, n do
self.threads:addjob(self._getFromThreads,
self._pushResult)
end
-- print(self.threads)
return self
end
function data._getFromThreads()
assert(opt.batchSize, 'opt.batchSize not found')
return trainLoader:sample(opt.batchSize)
end
function data._pushResult(...)
local res = {...}
if res == nil then
self.threads:synchronize()
end
result[1] = res
end
function data:getBatch()
-- queue another job
self.threads:addjob(self._getFromThreads, self._pushResult)
self.threads:dojob()
local res = result[1]
img_data = res[1]
img_paths = res[3]
result[1] = nil
if torch.type(img_data) == 'table' then
img_data = unpack(img_data)
end
return img_data, img_paths
end
function data:size()
return self._size
end
return data
local data_util = {}
require 'torch'
-- options = require '../options.lua'
-- load dataset from the file system
-- |name|: name of the dataset. It's currently either 'A' or 'B'
function data_util.load_dataset(name, opt, nc)
local tensortype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
local new_opt = options.clone(opt)
new_opt.manualSeed = torch.random(1, 10000) -- fix seed
new_opt.nc = nc
torch.manualSeed(new_opt.manualSeed)
local data_loader = paths.dofile('../data/data.lua')
new_opt.phase = new_opt.phase .. name
local data = data_loader.new(new_opt.nThreads, new_opt)
print("Dataset Size " .. name .. ": ", data:size())
torch.setdefaulttensortype(tensortype)
return data
end
return data_util
--[[
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'
local dataset = torch.class('dataLoader')
local initcheck = argcheck{
pack=true,
help=[[
A dataset class for images in a flat folder structure (folder-name is class-name).
Optimized for extremely large datasets (upwards of 14 million images).
Tested only on Linux (as it uses command-line linux utilities to scale up)
]],
{check=function(paths)
local out = true;
for k,v in ipairs(paths) do
if type(v) ~= 'string' then
print('paths can only be of string input');
out = false
end
end
return out
end,
name="paths",
type="table",
help="Multiple paths of directories with images"},
{name="sampleSize",
type="table",
help="a consistent sample size to resize the images"},
{name="split",
type="number",
help="Percentage of split to go to Training"
},
{name="serial_batches",
type="number",
help="if randomly sample training images"},
{name="samplingMode",
type="string",
help="Sampling mode: random | balanced ",
default = "balanced"},
{name="verbose",
type="boolean",
help="Verbose mode during initialization",
default = false},
{name="loadSize",
type="table",
help="a size to load the images to, initially",
opt = true},
{name="forceClasses",
type="table",
help="If you want this loader to map certain classes to certain indices, "
.. "pass a classes table that has {classname : classindex} pairs."
.. " For example: {3 : 'dog', 5 : 'cat'}"
.. "This function is very useful when you want two loaders to have the same "
.. "class indices (trainLoader/testLoader for example)",
opt = true},
{name="sampleHookTrain",
type="function",
help="applied to sample during training(ex: for lighting jitter). "
.. "It takes the image path as input",
opt = true},
{name="sampleHookTest",
type="function",
help="applied to sample during testing",
opt = true},
}
function dataset:__init(...)
-- argcheck
local args = initcheck(...)
print(args)
for k,v in pairs(args) do self[k] = v end
if not self.loadSize then self.loadSize = self.sampleSize; end
if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
self.image_count = 1
-- print('image_count_init', self.image_count)
-- find class names
self.classes = {}
local classPaths = {}
if self.forceClasses then
for k,v in pairs(self.forceClasses) do
self.classes[k] = v
classPaths[k] = {}
end
end
local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
-- loop over each paths folder, get list of unique class names,
-- also store the directory paths per class
-- for each class,
for k,path in ipairs(self.paths) do
-- print('path', path)
local dirs = {} -- hack
dirs[1] = path
-- local dirs = dir.getdirectories(path);
for k,dirpath in ipairs(dirs) do
local class = paths.basename(dirpath)
local idx = tableFind(self.classes, class)
-- print(class)
-- print(idx)
if not idx then
table.insert(self.classes, class)
idx = #self.classes
classPaths[idx] = {}
end
if not tableFind(classPaths[idx], dirpath) then
table.insert(classPaths[idx], dirpath);
end
end
end
self.classIndices = {}
for k,v in ipairs(self.classes) do
self.classIndices[v] = k
end
-- define command-line tools, try your best to maintain OSX compatibility
local wc = 'wc'
local cut = 'cut'
local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing
if ffi.os == 'OSX' then
wc = 'gwc'
cut = 'gcut'
find = 'gfind'
end
----------------------------------------------------------------------
-- Options for the GNU find command
local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
for i=2,#extensionList do
findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
end
-- find the image path names
self.imagePath = torch.CharTensor() -- path to each image in dataset
self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
self.classList = {} -- index of imageList to each image of a particular class
self.classListSample = self.classList -- the main list used when sampling data
print('running "find" on each class directory, and concatenate all'
.. ' those filenames into a single file containing all image paths for a given class')
-- so, generates one file per class
local classFindFiles = {}
for i=1,#self.classes do
classFindFiles[i] = os.tmpname()
end
local combinedFindList = os.tmpname();
local tmpfile = os.tmpname()
local tmphandle = assert(io.open(tmpfile, 'w'))
-- iterate over classes
for i, class in ipairs(self.classes) do
-- iterate over classPaths
for j,path in ipairs(classPaths[i]) do
local command = find .. ' "' .. path .. '" ' .. findOptions
.. ' >>"' .. classFindFiles[i] .. '" \n'
tmphandle:write(command)
end
end
io.close(tmphandle)
os.execute('bash ' .. tmpfile)
os.execute('rm -f ' .. tmpfile)
print('now combine all the files to a single large file')
local tmpfile = os.tmpname()
local tmphandle = assert(io.open(tmpfile, 'w'))
-- concat all finds to a single large file in the order of self.classes
for i=1,#self.classes do
local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
tmphandle:write(command)
end
io.close(tmphandle)
os.execute('bash ' .. tmpfile)
os.execute('rm -f ' .. tmpfile)
--==========================================================================
print('load the large concatenated list of sample paths to self.imagePath')
local cmd = wc .. " -L '"
.. combinedFindList .. "' |"
.. cut .. " -f1 -d' '"
print('cmd..' .. cmd)
local maxPathLength = tonumber(sys.fexecute(wc .. " -L '"
.. combinedFindList .. "' |"
.. cut .. " -f1 -d' '")) + 1
local length = tonumber(sys.fexecute(wc .. " -l '"
.. combinedFindList .. "' |"
.. cut .. " -f1 -d' '"))
assert(length > 0, "Could not find any image file in the given input paths")
assert(maxPathLength > 0, "paths of files are length 0?")
self.imagePath:resize(length, maxPathLength):fill(0)
local s_data = self.imagePath:data()
local count = 0
for line in io.lines(combinedFindList) do
ffi.copy(s_data, line)
s_data = s_data + maxPathLength
if self.verbose and count % 10000 == 0 then
xlua.progress(count, length)
end;
count = count + 1
end
self.numSamples = self.imagePath:size(1)
if self.verbose then print(self.numSamples .. ' samples found.') end
--==========================================================================
print('Updating classList and imageClass appropriately')
self.imageClass:resize(self.numSamples)
local runningIndex = 0
for i=1,#self.classes do
if self.verbose then xlua.progress(i, #(self.classes)) end
local length = tonumber(sys.fexecute(wc .. " -l '"
.. classFindFiles[i] .. "' |"
.. cut .. " -f1 -d' '"))
if length == 0 then
error('Class has zero samples')
else
self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long()
self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
end
runningIndex = runningIndex + length
end
--==========================================================================
-- clean up temporary files
print('Cleaning up temporary files')
local tmpfilelistall = ''
for i=1,#(classFindFiles) do
tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
if i % 1000 == 0 then
os.execute('rm -f ' .. tmpfilelistall)
tmpfilelistall = ''
end
end
os.execute('rm -f ' .. tmpfilelistall)
os.execute('rm -f "' .. combinedFindList .. '"')
--==========================================================================
if self.split == 100 then
self.testIndicesSize = 0
else
print('Splitting training and test sets to a ratio of '
.. self.split .. '/' .. (100-self.split))
self.classListTrain = {}
self.classListTest = {}
self.classListSample = self.classListTrain
local totalTestSamples = 0
-- split the classList into classListTrain and classListTest
for i=1,#self.classes do
local list = self.classList[i]
local count = self.classList[i]:size(1)
local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
local perm = torch.randperm(count)
self.classListTrain[i] = torch.LongTensor(splitidx)
for j=1,splitidx do
self.classListTrain[i][j] = list[perm[j]]
end
if splitidx == count then -- all samples were allocated to train set
self.classListTest[i] = torch.LongTensor()
else
self.classListTest[i] = torch.LongTensor(count-splitidx)
totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
local idx = 1
for j=splitidx+1,count do
self.classListTest[i][idx] = list[perm[j]]
idx = idx + 1
end
end
end
-- Now combine classListTest into a single tensor
self.testIndices = torch.LongTensor(totalTestSamples)
self.testIndicesSize = totalTestSamples
local tdata = self.testIndices:data()
local tidx = 0
for i=1,#self.classes do
local list = self.classListTest[i]
if list:dim() ~= 0 then
local ldata = list:data()
for j=0,list:size(1)-1 do
tdata[tidx] = ldata[j]
tidx = tidx + 1
end
end
end
end
end
-- size(), size(class)
function dataset:size(class, list)
list = list or self.classList
if not class then
return self.numSamples
elseif type(class) == 'string' then
return list[self.classIndices[class]]:size(1)
elseif type(class) == 'number' then
return list[class]:size(1)
end
end
-- getByClass
function dataset:getByClass(class)
local index = 0
if self.serial_batches == 1 then
index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1
self.image_count = self.image_count +1
else
index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
end
local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
return self:sampleHookTrain(imgpath), imgpath
end
-- converts a table of samples (and corresponding labels) to a clean tensor
local function tableToOutput(self, dataTable, scalarTable)
local data, scalarLabels, labels
if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
assert(#scalarTable == 1)
data = torch.Tensor(1,
dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3))
data[1]:copy(dataTable[1])
scalarLabels = torch.LongTensor(#scalarTable):fill(-1111)
else
local quantity = #scalarTable
data = torch.Tensor(quantity,
self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
scalarLabels = torch.LongTensor(quantity):fill(-1111)
for i=1,#dataTable do
data[i]:copy(dataTable[i])
scalarLabels[i] = scalarTable[i]
end
end
return data, scalarLabels
end
-- sampler, samples from the training set.
function dataset:sample(quantity)
assert(quantity)
local dataTable = {}
local scalarTable = {}
local samplePaths = {}
for i=1,quantity do
local class = torch.random(1, #self.classes)
local out, imgpath = self:getByClass(class)
table.insert(dataTable, out)
table.insert(scalarTable, class)
samplePaths[i] = imgpath
end
local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
return data, scalarLabels, samplePaths-- filePaths
end
function dataset:get(i1, i2)
local indices = torch.range(i1, i2);
local quantity = i2 - i1 + 1;
assert(quantity > 0)
-- now that indices has been initialized, get the samples
local dataTable = {}
local scalarTable = {}
for i=1,quantity do
-- load the sample
local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
local out = self:sampleHookTest(imgpath)
table.insert(dataTable, out)
table.insert(scalarTable, self.imageClass[indices[i]])
end
local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
return data, scalarLabels
end
return dataset
--[[
This data loader is a modified version of the one from dcgan.torch
(see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
require 'image'
paths.dofile('dataset.lua')
-- This file contains the data-loading logic and details.
-- It is run by each data-loader thread.
------------------------------------------
-------- COMMON CACHES and PATHS
-- Check for existence of opt.data
if opt.DATA_ROOT then
opt.data = paths.concat(opt.DATA_ROOT, opt.phase)
else
print(os.getenv('DATA_ROOT'))
opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase)
end
if not paths.dirp(opt.data) then
error('Did not find directory: ' .. opt.data)
end
-- a cache file of the training metadata (if doesnt exist, will be created)
local cache_prefix = opt.data:gsub('/', '_')
os.execute(('mkdir -p %s'):format(opt.cache_dir))
local trainCache = paths.concat(opt.cache_dir, cache_prefix .. '_trainCache.t7')
--------------------------------------------------------------------------------------------
local input_nc = opt.nc -- input channels
local loadSize = {input_nc, opt.loadSize}
local sampleSize = {input_nc, opt.fineSize}
local function loadImage(path)
local input = image.load(path, 3, 'float')
local h = input:size(2)
local w = input:size(3)
local imA = image.crop(input, 0, 0, w/2, h)
imA = image.scale(imA, loadSize[2], loadSize[2])
local imB = image.crop(input, w/2, 0, w, h)
imB = image.scale(imB, loadSize[2], loadSize[2])
local perm = torch.LongTensor{3, 2, 1}
imA = imA:index(1, perm)
imA = imA:mul(2):add(-1)
imB = imB:index(1, perm)
imB = imB:mul(2):add(-1)
assert(imA:max()<=1,"A: badly scaled inputs")
assert(imA:min()>=-1,"A: badly scaled inputs")
assert(imB:max()<=1,"B: badly scaled inputs")
assert(imB:min()>=-1,"B: badly scaled inputs")
local oW = sampleSize[2]
local oH = sampleSize[2]
local iH = imA:size(2)
local iW = imA:size(3)
if iH~=oH then
h1 = math.ceil(torch.uniform(1e-2, iH-oH))
end
if iW~=oW then
w1 = math.ceil(torch.uniform(1e-2, iW-oW))
end
if iH ~= oH or iW ~= oW then
imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
end
if opt.flip == 1 and torch.uniform() > 0.5 then
imA = image.hflip(imA)
imB = image.hflip(imB)
end
local concatenated = torch.cat(imA,imB,1)
return concatenated
end
local function loadSingleImage(path)
local im = image.load(path, input_nc, 'float')
if opt.resize_or_crop == 'resize_and_crop' then
im = image.scale(im, loadSize[2], loadSize[2])
end
if input_nc == 3 then
local perm = torch.LongTensor{3, 2, 1}
im = im:index(1, perm)--:mul(256.0): brg, rgb
im = im:mul(2):add(-1)
end
assert(im:max()<=1,"A: badly scaled inputs")
assert(im:min()>=-1,"A: badly scaled inputs")
local oW = sampleSize[2]
local oH = sampleSize[2]
local iH = im:size(2)
local iW = im:size(3)
if (opt.resize_or_crop == 'resize_and_crop' ) then
local h1, w1 = 0, 0
if iH~=oH then
h1 = math.ceil(torch.uniform(1e-2, iH-oH))
end
if iW~=oW then
w1 = math.ceil(torch.uniform(1e-2, iW-oW))
end
if iH ~= oH or iW ~= oW then
im = image.crop(im, w1, h1, w1 + oW, h1 + oH)
end
elseif (opt.resize_or_crop == 'combined') then
local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2)
local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2)
local h1 = math.ceil(torch.uniform(1e-2, iH-sH))
local w1 = math.ceil(torch.uniform(1e-2, iW-sW))
im = image.crop(im, w1, h1, w1 + sW, h1 + sH)
im = image.scale(im, oW, oH)
elseif (opt.resize_or_crop == 'crop') then
local w = math.min(math.min(oH, iH),iW)
w = math.floor(w/4)*4
local x = math.floor(torch.uniform(0, iW - w))
local y = math.floor(torch.uniform(0, iH - w))
im = image.crop(im, x, y, x+w, y+w)
elseif (opt.resize_or_crop == 'scale_width') then
w = oW
h = torch.floor(iH * oW/iW)
im = image.scale(im, w, h)
elseif (opt.resize_or_crop == 'scale_height') then
h = oH
w = torch.floor(iW * oH / iH)
im = image.scale(im, w, h)
end
if opt.flip == 1 and torch.uniform() > 0.5 then
im = image.hflip(im)
end
return im
end
-- channel-wise mean and std. Calculate or load them from disk later in the script.
local mean,std
--------------------------------------------------------------------------------
-- Hooks that are used for each image that is loaded
-- function to load the image, jitter it appropriately (random crops etc.)
local trainHook_singleimage = function(self, path)
collectgarbage()
-- print('load single image')
local im = loadSingleImage(path)
return im
end
-- function that loads images that have juxtaposition
-- of two images from two domains
local trainHook_doubleimage = function(self, path)
-- print('load double image')
collectgarbage()
local im = loadImage(path)
return im
end
if opt.align_data > 0 then
sample_nc = input_nc*2
trainHook = trainHook_doubleimage
else
sample_nc = input_nc
trainHook = trainHook_singleimage
end
trainLoader = dataLoader{
paths = {opt.data},
loadSize = {input_nc, loadSize[2], loadSize[2]},
sampleSize = {sample_nc, sampleSize[2], sampleSize[2]},
split = 100,
serial_batches = opt.serial_batches,
verbose = true
}
trainLoader.sampleHookTrain = trainHook
collectgarbage()
-- do some sanity checks on trainLoader
do
local class = trainLoader.imageClass
local nClasses = #trainLoader.classes
assert(class:max() <= nClasses, "class logic has error")
assert(class:min() >= 1, "class logic has error")
end
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the datasets are not aligned.
-- The datasets can have different sizes
--------------------------------------------------------------------------------
require 'data.base_data_loader'
local class = require 'class'
data_util = paths.dofile('data_util.lua')
UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader')
function UnalignedDataLoader:__init(conf)
BaseDataLoader.__init(self, conf)
conf = conf or {}
end
function UnalignedDataLoader:name()
return 'UnalignedDataLoader'
end
function UnalignedDataLoader:Initialize(opt)
opt.align_data = 0
self.dataA = data_util.load_dataset('A', opt, opt.input_nc)
self.dataB = data_util.load_dataset('B', opt, opt.output_nc)
end
-- actually fetches the data
-- |return|: a table of two tables, each corresponding to
-- the batch for dataset A and dataset B
function UnalignedDataLoader:LoadBatchForAllDatasets()
local batchA, pathA = self.dataA:getBatch()
local batchB, pathB = self.dataB:getBatch()
return batchA, batchB, pathA, pathB
end
-- returns the size of each dataset
function UnalignedDataLoader:size(dataset)
if dataset == 'A' then
return self.dataA:size()
end
if dataset == 'B' then
return self.dataB:size()
end
return math.max(self.dataA:size(), self.dataB:size())
-- return the size of the largest dataset by default
end
#!/bin/sh
## This script download the dataset and pre-trained network,
## and generates style transferred images.
# Download the dataset. The downloaded dataset is stored in ./datasets/${DATASET_NAME}
DATASET_NAME='ae_photos'
bash ./datasets/download_dataset.sh $DATASET_NAME
# Download the pre-trained model. The downloaded model is stored in ./models/${MODEL_NAME}_pretrained/latest_net_G.t7
MODEL_NAME='style_vangogh'
bash ./pretrained_models/download_model.sh $MODEL_NAME
# Run style transfer using the downloaded dataset and model
DATA_ROOT=./datasets/$DATASET_NAME name=${MODEL_NAME}_pretrained model=one_direction_test phase=test how_many='all' loadSize=256 fineSize=256 resize_or_crop='scale_width' th test.lua
if [ $? == 0 ]; then
echo "The result can be viewed at ./results/${MODEL_NAME}_pretrained/latest_test/index.html"
fi
DB_NAME='maps'
GPU_ID=1
DISPLAY_ID=1
NET_G=resnet_6blocks
NET_D=basic
MODEL=cycle_gan
SAVE_EPOCH=5
ALIGN_DATA=0
LAMBDA=10
NF=64
EXPR_NAME=${DB_NAME}_${MODEL}_${LAMBDA}
CHECKPOINT_DIR=./checkpoints/
LOG_FILE=${CHECKPOINT_DIR}${EXPR_NAME}/log.txt
mkdir -p ${CHECKPOINT_DIR}${EXPR_NAME}
DATA_ROOT=./datasets/$DB_NAME align_data=$ALIGN_DATA use_lsgan=1 \
which_direction='AtoB' display_plot=$PLOT pool_size=50 niter=100 niter_decay=100 \
which_model_netG=$NET_G which_model_netD=$NET_D model=$MODEL lr=0.0002 print_freq=200 lambda_A=$LAMBDA lambda_B=$LAMBDA \
loadSize=143 fineSize=128 gpu=$GPU_ID display_winsize=128 \
name=$EXPR_NAME flip=1 save_epoch_freq=$SAVE_EPOCH \
continue_train=0 display_id=$DISPLAY_ID \
checkpoints_dir=$CHECKPOINT_DIR\
th train.lua | tee -a $LOG_FILE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment