Unverified commit b1dad2e9, authored by wfng92 and committed by GitHub
Browse files

Make center crop and random flip as args for unconditional image generation (#2259)

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip
parent cd524755
...@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru ...@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru
```bash ```bash
accelerate launch train_unconditional_ort.py \ accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \ --output_dir="ddpm-ema-flowers-64" \
--use_ema \ --use_ema \
--train_batch_size=16 \ --train_batch_size=16 \
......
...@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel ...@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import ( from torchvision import transforms
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -105,6 +97,21 @@ def parse_args(): ...@@ -105,6 +97,21 @@ def parse_args():
" resolution" " resolution"
), ),
) )
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument( parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
) )
...@@ -369,13 +376,13 @@ def main(args): ...@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation. # Preprocessing the datasets and DataLoaders creation.
augmentations = Compose( augmentations = transforms.Compose(
[ [
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
CenterCrop(args.resolution), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
RandomHorizontalFlip(), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
ToTensor(), transforms.ToTensor(),
Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
) )
......
...@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset: ...@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \ --output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
...@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset: ...@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset_name="huggan/pokemon" \ --dataset_name="huggan/pokemon" \
--resolution=64 \ --resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-pokemon-64" \ --output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
......
...@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler ...@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision.transforms import ( from torchvision import transforms
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -105,6 +97,21 @@ def parse_args(): ...@@ -105,6 +97,21 @@ def parse_args():
" resolution" " resolution"
), ),
) )
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument( parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
) )
...@@ -369,13 +376,13 @@ def main(args): ...@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation. # Preprocessing the datasets and DataLoaders creation.
augmentations = Compose( augmentations = transforms.Compose(
[ [
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
CenterCrop(args.resolution), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
RandomHorizontalFlip(), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
ToTensor(), transforms.ToTensor(),
Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment