Unverified commit d82b0323 authored by Sayak Paul, committed by GitHub

[Examples] Add streaming support to the ControlNet training example in JAX (#2859)



* improve stable unclip doc.

* feat: add streaming support to controlnet flax training script.

* fix: CLI arg.

* fix: torch dataloader shuffle setting.

* fix: dataset length.

* fix: wandb config.

* fix: steps_per_epoch in the training loop.

* add: entry about streaming in the readme

* get column names from iterable dataset + fix final logging

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 40a7b862
@@ -335,7 +335,7 @@ huggingface-cli login
 Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
-```
+```bash
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
 export OUTPUT_DIR="control_out"
 export HUB_MODEL_ID="fill-circle-controlnet"
@@ -343,7 +343,7 @@ export HUB_MODEL_ID="fill-circle-controlnet"
 And finally start the training
-```
+```bash
 python3 train_controlnet_flax.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
@@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \
 ```
 Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. To enable streaming, one must also set `--max_train_samples`. Here is an example command:
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+ --streaming \
+ --conditioning_image_column=spiga_seg \
+ --image_column=image \
+ --caption_column=image_caption \
+ --resolution=512 \
+ --max_train_samples 50 \
+ --max_train_steps 5 \
+ --learning_rate=1e-5 \
+ --validation_steps=2 \
+ --train_batch_size=1 \
+ --revision="flax" \
+ --report_to="wandb"
+```
+Note, however, that TPU performance might get bottlenecked, because streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
+* [Webdataset](https://webdataset.github.io/webdataset/)
+* [TorchData](https://github.com/pytorch/data)
+* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
\ No newline at end of file
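
Editorial aside, not part of the commit: the `--streaming` flag builds on the 🤗 Datasets streaming API. Below is a minimal sketch of the behavior the script relies on, using the dataset from the example command above; the seed, buffer size, and sample cap are illustrative.

```python
from datasets import load_dataset

# Streaming returns an IterableDataset: nothing is downloaded up front and
# the stream has no len(), which is why the script requires
# `--max_train_samples` whenever `--streaming` is passed.
dataset = load_dataset(
    "multimodalart/facesyntheticsspigacaptioned",
    split="train",
    streaming=True,
)

# shuffle() on a stream samples from a fixed-size buffer rather than the
# whole dataset; take() caps the stream at a known number of examples.
dataset = dataset.shuffle(seed=0, buffer_size=1000).take(50)

# A stream does not expose .column_names the way a map-style dataset does,
# so the commit peeks at the first example instead.
print(next(iter(dataset)).keys())
```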
@@ -35,6 +35,7 @@ from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
+from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
...@@ -206,7 +207,7 @@ def parse_args(): ...@@ -206,7 +207,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--from_pt", "--from_pt",
action="store_true", action="store_true",
help="Load the pretrained model from a pytorch checkpoint.", help="Load the pretrained model from a PyTorch checkpoint.",
) )
parser.add_argument( parser.add_argument(
"--tokenizer_name", "--tokenizer_name",
@@ -332,6 +333,7 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
+    parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
     parser.add_argument(
         "--dataset_config_name",
         type=str,
@@ -369,7 +371,7 @@ def parse_args():
         default=None,
         help=(
             "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "value if set. Needed if `streaming` is set to True."
         ),
     )
     parser.add_argument(
@@ -453,10 +455,15 @@ def parse_args():
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )

+    # This idea comes from
+    # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+    if args.streaming and args.max_train_samples is None:
+        raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
     return args


-def make_train_dataset(args, tokenizer):
+def make_train_dataset(args, tokenizer, batch_size=None):
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            streaming=args.streaming,
         )
     else:
         data_files = {}
@@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    column_names = dataset["train"].column_names
+    if isinstance(dataset["train"], IterableDataset):
+        column_names = next(iter(dataset["train"])).keys()
+    else:
+        column_names = dataset["train"].column_names

     # 6. Get the column names for input/target.
     if args.image_column is None:
@@ -565,9 +576,20 @@ def make_train_dataset(args, tokenizer):
     if jax.process_index() == 0:
         if args.max_train_samples is not None:
-            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+            if args.streaming:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+            else:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        if args.streaming:
+            train_dataset = dataset["train"].map(
+                preprocess_train,
+                batched=True,
+                batch_size=batch_size,
+                remove_columns=list(dataset["train"].features.keys()),
+            )
+        else:
+            train_dataset = dataset["train"].with_transform(preprocess_train)

     return train_dataset
@@ -661,12 +683,12 @@ def main():
         raise NotImplementedError("No tokenizer specified!")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
-    train_dataset = make_train_dataset(args, tokenizer)
     total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+    train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        shuffle=True,
+        shuffle=not args.streaming,
         collate_fn=collate_fn,
         batch_size=total_train_batch_size,
         num_workers=args.dataloader_num_workers,
@@ -897,7 +919,11 @@ def main():
     vae_params = jax_utils.replicate(vae_params)

     # Train!
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.streaming:
+        dataset_length = args.max_train_samples
+    else:
+        dataset_length = len(train_dataloader)
+    num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)

     # Scheduler and math around the number of training steps.
     if args.max_train_steps is None:
@@ -906,7 +932,7 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
@@ -916,7 +942,7 @@ def main():
         wandb.define_metric("*", step_metric="train/step")
         wandb.config.update(
             {
-                "num_train_examples": len(train_dataset),
+                "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
@@ -935,7 +961,11 @@ def main():
         train_metrics = []

-        steps_per_epoch = len(train_dataset) // total_train_batch_size
+        steps_per_epoch = (
+            args.max_train_samples // total_train_batch_size
+            if args.streaming
+            else len(train_dataset) // total_train_batch_size
+        )
         train_step_progress_bar = tqdm(
             total=steps_per_epoch,
             desc="Training...",
@@ -980,7 +1010,8 @@ def main():
     # Create the pipeline using using the trained modules and save it.
     if jax.process_index() == 0:
-        image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+        if args.validation_prompt is not None:
+            image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
         controlnet.save_pretrained(
             args.output_dir,
...
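
Editorial note on the pattern in the diff above: once the training set is iterable-style, the torch `DataLoader` can no longer shuffle it, and no step count can be read from `len()`, so everything derives from `max_train_samples`. Here is a minimal, self-contained sketch of that behavior; the `FakeStream` class is a hypothetical stand-in for the streamed 🤗 dataset, and all numbers are illustrative.

```python
import math
from torch.utils.data import DataLoader, IterableDataset

class FakeStream(IterableDataset):
    """Stands in for a streamed dataset: defines __iter__ but no __len__."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        for i in range(self.num_samples):
            yield {"pixel_values": float(i)}

max_train_samples = 50            # must be given explicitly: the stream has no len()
total_train_batch_size = 8
gradient_accumulation_steps = 1

train_dataset = FakeStream(max_train_samples)

# shuffle=True raises a ValueError for an iterable-style dataset,
# hence `shuffle=not args.streaming` in the training script.
train_dataloader = DataLoader(train_dataset, batch_size=total_train_batch_size, shuffle=False)

# len(train_dataloader) is undefined here, so step counts come from the cap,
# mirroring the `dataset_length`/`steps_per_epoch` changes in the diff.
steps_per_epoch = max_train_samples // total_train_batch_size
num_update_steps_per_epoch = math.ceil(max_train_samples / gradient_accumulation_steps)

for batch in train_dataloader:
    print(batch["pixel_values"].shape)  # torch.Size([8]) for full batches
    break
```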