Unverified Commit b1dad2e9 authored by wfng92's avatar wfng92 Committed by GitHub
Browse files

Make center crop and random flip as args for unconditional image generation (#2259)

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip
parent cd524755
......@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru
```bash
accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \
--use_ema \
--train_batch_size=16 \
......
......@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from torchvision import transforms
from tqdm.auto import tqdm
......@@ -105,6 +97,21 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
......@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
augmentations = Compose(
augmentations = transforms.Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
......
......@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
--num_epochs=100 \
......@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/pokemon" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \
--num_epochs=100 \
......
......@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from torchvision import transforms
from tqdm.auto import tqdm
......@@ -105,6 +97,21 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
......@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
augmentations = Compose(
augmentations = transforms.Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment