Unverified Commit b1dad2e9 authored by wfng92's avatar wfng92 Committed by GitHub
Browse files

Make center crop and random flip as args for unconditional image generation (#2259)

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip
parent cd524755
...@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru ...@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru
```bash ```bash
accelerate launch train_unconditional_ort.py \ accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \ --output_dir="ddpm-ema-flowers-64" \
--use_ema \ --use_ema \
--train_batch_size=16 \ --train_batch_size=16 \
...@@ -47,4 +47,4 @@ accelerate launch train_unconditional_ort.py \ ...@@ -47,4 +47,4 @@ accelerate launch train_unconditional_ort.py \
--mixed_precision=fp16 --mixed_precision=fp16
``` ```
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions. Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
\ No newline at end of file
...@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel ...@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import ( from torchvision import transforms
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -105,6 +97,21 @@ def parse_args(): ...@@ -105,6 +97,21 @@ def parse_args():
" resolution" " resolution"
), ),
) )
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument( parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
) )
...@@ -369,13 +376,13 @@ def main(args): ...@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation. # Preprocessing the datasets and DataLoaders creation.
augmentations = Compose( augmentations = transforms.Compose(
[ [
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
CenterCrop(args.resolution), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
RandomHorizontalFlip(), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
ToTensor(), transforms.ToTensor(),
Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
) )
......
...@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset: ...@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \ --output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
...@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset: ...@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset_name="huggan/pokemon" \ --dataset_name="huggan/pokemon" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-pokemon-64" \ --output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
...@@ -139,4 +139,4 @@ dataset.push_to_hub("name_of_your_dataset", private=True) ...@@ -139,4 +139,4 @@ dataset.push_to_hub("name_of_your_dataset", private=True)
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
\ No newline at end of file
...@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler ...@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision.transforms import ( from torchvision import transforms
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -105,6 +97,21 @@ def parse_args(): ...@@ -105,6 +97,21 @@ def parse_args():
" resolution" " resolution"
), ),
) )
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument( parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
) )
...@@ -369,13 +376,13 @@ def main(args): ...@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation. # Preprocessing the datasets and DataLoaders creation.
augmentations = Compose( augmentations = transforms.Compose(
[ [
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
CenterCrop(args.resolution), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
RandomHorizontalFlip(), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
ToTensor(), transforms.ToTensor(),
Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment