Unverified Commit ee20d1f8 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

update flax controlnet training script (#2951)

* load_from_disk + checkpointing_steps

* apply feedback
parent 0d0fa2a3
...@@ -27,13 +27,13 @@ import optax ...@@ -27,13 +27,13 @@ import optax
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
import transformers import transformers
from datasets import load_dataset from datasets import load_dataset, load_from_disk
from flax import jax_utils from flax import jax_utils
from flax.core.frozen_dict import unfreeze from flax.core.frozen_dict import unfreeze
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import shard from flax.training.common_utils import shard
from huggingface_hub import create_repo, upload_folder from huggingface_hub import create_repo, upload_folder
from PIL import Image from PIL import Image, PngImagePlugin
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -49,6 +49,11 @@ from diffusers import ( ...@@ -49,6 +49,11 @@ from diffusers import (
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
# see more https://github.com/python-pillow/Pillow/issues/5610
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
if is_wandb_available(): if is_wandb_available():
import wandb import wandb
...@@ -246,6 +251,12 @@ def parse_args(): ...@@ -246,6 +251,12 @@ def parse_args():
default=None, default=None,
help="Total number of training steps to perform.", help="Total number of training steps to perform.",
) )
parser.add_argument(
"--checkpointing_steps",
type=int,
default=5000,
help=("Save a checkpoint of the training state every X updates."),
)
parser.add_argument( parser.add_argument(
"--learning_rate", "--learning_rate",
type=float, type=float,
...@@ -344,9 +355,17 @@ def parse_args(): ...@@ -344,9 +355,17 @@ def parse_args():
type=str, type=str,
default=None, default=None,
help=( help=(
"A folder containing the training data. Folder contents must follow the structure described in" "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified." "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--load_from_disk",
action="store_true",
help=(
"If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
"See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
), ),
) )
parser.add_argument( parser.add_argument(
...@@ -478,6 +497,11 @@ def make_train_dataset(args, tokenizer, batch_size=None): ...@@ -478,6 +497,11 @@ def make_train_dataset(args, tokenizer, batch_size=None):
) )
else: else:
if args.train_data_dir is not None: if args.train_data_dir is not None:
if args.load_from_disk:
dataset = load_from_disk(
args.train_data_dir,
)
else:
dataset = load_dataset( dataset = load_dataset(
args.train_data_dir, args.train_data_dir,
cache_dir=args.cache_dir, cache_dir=args.cache_dir,
...@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None): ...@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
image_transforms = transforms.Compose( image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
...@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None): ...@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
conditioning_image_transforms = transforms.Compose( conditioning_image_transforms = transforms.Compose(
[ [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(), transforms.ToTensor(),
] ]
) )
...@@ -981,6 +1007,11 @@ def main(): ...@@ -981,6 +1007,11 @@ def main():
"train/loss": jax_utils.unreplicate(train_metric)["loss"], "train/loss": jax_utils.unreplicate(train_metric)["loss"],
} }
) )
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained(
f"{args.output_dir}/{global_step}",
params=get_params_to_save(state.params),
)
train_metric = jax_utils.unreplicate(train_metric) train_metric = jax_utils.unreplicate(train_metric)
train_step_progress_bar.close() train_step_progress_bar.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment