Commit d06e0694 authored by Andreas Steiner, committed by GitHub (unverified)

Adds profiling flags, computes train metrics average. (#3053)

* WIP controlnet training

- bugfix --streaming
- bugfix running report_to!='wandb'
- adds memory profile before validation

* Adds final logging statement.

* Sets train epochs to 11.

Looking at a longer ~16ep run, we see only good validation images
after ~11ep:

https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8



* Removes --logging_dir (it's not used).

* Adds --profile flags.

* Updates --output_dir=runs/fill-circle-{timestamp}.

* Compute mean of `train_metrics`.

Previously `train_metrics[-1]` was logged, resulting in very bumpy train
metrics.

* Improves logging a bit.

- adds l2_grads gradient norm logging
- adds steps_per_sec
- sets walltime as x coordinate of train/step
- logs controlnet_params config

* Adds --ccache (doesn't really help though).

* minor fix in controlnet flax example (#2986)

* fix the error when push_to_hub but not log validation

* controlnet_from_pt & controlnet_revision

* add intermediate checkpointing to the guide

* Bugfix --profile_steps

* Sets `TRACKER_PROJECT_NAME='controlnet_fill50k'`.

* Logs fractional epoch.

* Adds relative `walltime` metric.

* Adds `StepTraceAnnotation` and uses `global_step` instead of `step`.

* Applied `black`.

* Streamlines commands in README a bit.

* Removes `--ccache`.

This makes only a very small difference (~1 min) with this model size, so removing
the option introduced in cdb3cc.

* Re-ran `black`.

* Update examples/controlnet/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Converts spaces to tab.

* Removes repeated args.

* Skips first step (compilation) in profiling

* Updates README with profiling instructions.

* Unifies tabs/spaces in README.

* Re-ran style & quality.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0a73b4d3
...@@ -284,9 +284,9 @@ TPU_TYPE=v4-8 ...@@ -284,9 +284,9 @@ TPU_TYPE=v4-8
VM_NAME=hg_flax VM_NAME=hg_flax
gcloud alpha compute tpus tpu-vm create $VM_NAME \ gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone $ZONE \ --zone $ZONE \
--accelerator-type $TPU_TYPE \ --accelerator-type $TPU_TYPE \
--version tpu-vm-v4-base --version tpu-vm-v4-base
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \ gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
``` ```
...@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n ...@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
pip install wandb pip install wandb
``` ```
Now let's download two conditioning images that we will use to run validation during training in order to track our progress.
``` ```
...@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v ...@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v
```bash ```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5" export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out" export OUTPUT_DIR="runs/fill-circle-{timestamp}"
export HUB_MODEL_ID="fill-circle-controlnet" export HUB_MODEL_ID="controlnet-fill-circle"
``` ```
And finally start the training And finally start the training
...@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \ ...@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \
--revision="non-ema" \ --revision="non-ema" \
--from_pt \ --from_pt \
--report_to="wandb" \ --report_to="wandb" \
--max_train_steps=10000 \ --tracker_project_name=$HUB_MODEL_ID \
--num_train_epochs=11 \
--push_to_hub \ --push_to_hub \
--hub_model_id=$HUB_MODEL_ID --hub_model_id=$HUB_MODEL_ID
``` ```
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command: Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
```bash ```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
export HUB_MODEL_ID="controlnet-uncanny-faces"
python3 train_controlnet_flax.py \ python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \ --pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \ --output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \ --streaming \
--conditioning_image_column=spiga_seg \ --conditioning_image_column=spiga_seg \
--image_column=image \ --image_column=image \
--caption_column=image_caption \ --caption_column=image_caption \
--resolution=512 \ --resolution=512 \
--max_train_samples 50 \ --max_train_samples 100000 \
--max_train_steps 5 \ --learning_rate=1e-5 \
--learning_rate=1e-5 \ --train_batch_size=1 \
--validation_steps=2 \ --revision="flax" \
--train_batch_size=1 \ --report_to="wandb" \
--revision="flax" \ --tracker_project_name=$HUB_MODEL_ID
--report_to="wandb"
``` ```
Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
...@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream ...@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
```bash ```bash
--checkpointing_steps=500 --checkpointing_steps=500
``` ```
This will save the trained model in subfolders of your output_dir. Each subfolder is named after the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named 500.
You can then start your training from this saved checkpoint with You can then start your training from this saved checkpoint with
```bash ```bash
--controlnet_model_name_or_path="./control_out/500" --controlnet_model_name_or_path="./control_out/500"
``` ```
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
You can **profile your code** with:
```bash
--profile_steps=5
```
Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
```bash
pip install tensorflow tensorboard-plugin-profile
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
```
The profile can then be inspected at http://localhost:6006/#profile
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import math import math
import os import os
import random import random
import time
from pathlib import Path from pathlib import Path
import jax import jax
...@@ -220,6 +221,28 @@ def parse_args(): ...@@ -220,6 +221,28 @@ def parse_args():
default=None, default=None,
help="Revision of controlnet model identifier from huggingface.co/models.", help="Revision of controlnet model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--profile_steps",
type=int,
default=0,
help="How many training steps to profile in the beginning.",
)
parser.add_argument(
"--profile_validation",
action="store_true",
help="Whether to profile the (last) validation.",
)
parser.add_argument(
"--profile_memory",
action="store_true",
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
)
parser.add_argument(
"--ccache",
type=str,
default=None,
help="Enables compilation cache.",
)
parser.add_argument( parser.add_argument(
"--controlnet_from_pt", "--controlnet_from_pt",
action="store_true", action="store_true",
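Note on the hunk above: the three profiling flags can be tried in isolation. The following is a minimal sketch, assuming a stand-alone parser that copies only those options (the `build_parser` helper is hypothetical and not part of `train_controlnet_flax.py`; the help strings are shortened):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: reproduces only the profiling flags from this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--profile_steps",
        type=int,
        default=0,
        help="How many training steps to profile in the beginning.",
    )
    parser.add_argument(
        "--profile_validation",
        action="store_true",
        help="Whether to profile the (last) validation.",
    )
    parser.add_argument(
        "--profile_memory",
        action="store_true",
        help="Whether to dump an initial and a final memory profile.",
    )
    return parser


# With no arguments, all profiling is switched off.
defaults = build_parser().parse_args([])

# A single "=" separates the flag name from its integer value.
enabled = build_parser().parse_args(["--profile_steps=5", "--profile_memory"])
```

Because `--profile_steps` defaults to `0`, the `if args.profile_steps and ...` checks in the training loop are skipped unless the flag is given explicitly.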
...@@ -234,8 +257,9 @@ def parse_args(): ...@@ -234,8 +257,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--output_dir", "--output_dir",
type=str, type=str,
default="controlnet-model", default="runs/{timestamp}",
help="The output directory where the model predictions and checkpoints will be written.", help="The output directory where the model predictions and checkpoints will be written. "
"Can contain placeholders: {timestamp}.",
) )
parser.add_argument( parser.add_argument(
"--cache_dir", "--cache_dir",
...@@ -317,15 +341,6 @@ def parse_args(): ...@@ -317,15 +341,6 @@ def parse_args():
default=None, default=None,
help="The name of the repository to keep in sync with the local `output_dir`.", help="The name of the repository to keep in sync with the local `output_dir`.",
) )
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument( parser.add_argument(
"--logging_steps", "--logging_steps",
type=int, type=int,
...@@ -459,6 +474,8 @@ def parse_args(): ...@@ -459,6 +474,8 @@ def parse_args():
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args() args = parser.parse_args()
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank: if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank args.local_rank = env_local_rank
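Note on the hunk above: the `{timestamp}` placeholder is expanded with a plain string replacement after argument parsing. A small sketch of that behaviour, assuming a hypothetical `expand_output_dir` helper with an optional fixed time so the result is reproducible (the script itself always uses the current local time):

```python
import time
from typing import Optional


def expand_output_dir(template: str, now: Optional[time.struct_time] = None) -> str:
    # Same format string as in the commit: YYYYmmdd_HHMMSS.
    fmt = "%Y%m%d_%H%M%S"
    stamp = time.strftime(fmt, now) if now is not None else time.strftime(fmt)
    # str.replace leaves any other text (including other braces) untouched.
    return template.replace("{timestamp}", stamp)


# Fixed time used only to make the example deterministic.
fixed = time.strptime("2023-04-11 16:56:12", "%Y-%m-%d %H:%M:%S")
run_dir = expand_output_dir("runs/fill-circle-{timestamp}", fixed)
plain_dir = expand_output_dir("control_out", fixed)
```

Using `str.replace` rather than `str.format` means a path without the placeholder passes through unchanged, and the shell hands `runs/fill-circle-{timestamp}` to the script literally, since the quoted value contains nothing for the shell to expand.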
...@@ -952,6 +969,11 @@ def main(): ...@@ -952,6 +969,11 @@ def main():
metrics = {"loss": loss} metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch") metrics = jax.lax.pmean(metrics, axis_name="batch")
def l2(xs):
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
return new_state, metrics, new_train_rng return new_state, metrics, new_train_rng
# Create parallel version of the train step # Create parallel version of the train step
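Note on the hunk above: `l2_grads` is the global L2 norm over all gradient leaves, that is, the square root of the sum of all squared entries. A plain-Python sketch of the same quantity, assuming real-valued gradients stored as numbers in nested dicts and lists (the `leaves` and `global_l2_norm` helpers are illustrative stand-ins for `jax.tree_util.tree_leaves` and the `l2` function, not the script's code):

```python
import math


def leaves(tree):
    # Yield every number in a nested dict/list/tuple structure.
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from leaves(value)
    else:
        yield tree


def global_l2_norm(tree) -> float:
    # sqrt(sum of squares) across all leaves, not a per-layer norm.
    return math.sqrt(sum(x * x for x in leaves(tree)))


# Toy gradient tree: squared entries sum to 9 + 16 = 25.
grads = {"conv": {"kernel": [[3.0, 0.0], [0.0, 0.0]], "bias": [4.0]}}
norm = global_l2_norm(grads)
```

In the hunk, `tree_leaves` is applied twice (once at the call site and once inside `l2`); flattening an already flat list of arrays returns the same leaves, so the second call should not change the result.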
...@@ -983,32 +1005,38 @@ def main(): ...@@ -983,32 +1005,38 @@ def main():
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}") logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
if jax.process_index() == 0: if jax.process_index() == 0 and args.report_to == "wandb":
wandb.define_metric("*", step_metric="train/step") wandb.define_metric("*", step_metric="train/step")
wandb.define_metric("train/step", step_metric="walltime")
wandb.config.update( wandb.config.update(
{ {
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset), "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size, "total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(), "num_devices": jax.device_count(),
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
} }
) )
global_step = 0 global_step = step0 = 0
epochs = tqdm( epochs = tqdm(
range(args.num_train_epochs), range(args.num_train_epochs),
desc="Epoch ... ", desc="Epoch ... ",
position=0, position=0,
disable=jax.process_index() > 0, disable=jax.process_index() > 0,
) )
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
t00 = t0 = time.monotonic()
for epoch in epochs: for epoch in epochs:
# ======================== Training ================================ # ======================== Training ================================
train_metrics = [] train_metrics = []
train_metric = None
steps_per_epoch = ( steps_per_epoch = (
args.max_train_samples // total_train_batch_size args.max_train_samples // total_train_batch_size
if args.streaming if args.streaming or args.max_train_samples
else len(train_dataset) // total_train_batch_size else len(train_dataset) // total_train_batch_size
) )
train_step_progress_bar = tqdm( train_step_progress_bar = tqdm(
...@@ -1020,10 +1048,18 @@ def main(): ...@@ -1020,10 +1048,18 @@ def main():
) )
# train # train
for batch in train_dataloader: for batch in train_dataloader:
if args.profile_steps and global_step == 1:
train_metric["loss"].block_until_ready()
jax.profiler.start_trace(args.output_dir)
if args.profile_steps and global_step == 1 + args.profile_steps:
train_metric["loss"].block_until_ready()
jax.profiler.stop_trace()
batch = shard(batch) batch = shard(batch)
state, train_metric, train_rngs = p_train_step( with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
state, unet_params, text_encoder_params, vae_params, batch, train_rngs state, train_metric, train_rngs = p_train_step(
) state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
train_metrics.append(train_metric) train_metrics.append(train_metric)
train_step_progress_bar.update(1) train_step_progress_bar.update(1)
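Note on the hunk above: the trace starts when `global_step == 1` and stops when `global_step == 1 + args.profile_steps`, so the first step (which includes compilation) is left out. The window can be checked with a small simulation, assuming that `global_step` is incremented once per batch after the train step (the increment is not visible in this hunk) and replacing `jax.profiler.start_trace` / `stop_trace` with a boolean flag:

```python
from typing import List


def profiled_steps(total_steps: int, profile_steps: int) -> List[int]:
    """Return the step numbers that would fall inside the trace window."""
    traced: List[int] = []
    tracing = False
    global_step = 0
    for _ in range(total_steps):
        # Same two conditions as in the hunk; the profiler calls become a flag.
        if profile_steps and global_step == 1:
            tracing = True
        if profile_steps and global_step == 1 + profile_steps:
            tracing = False
        if tracing:
            traced.append(global_step)
        # Assumption: one increment per batch, after the train step.
        global_step += 1
    return traced


window = profiled_steps(total_steps=10, profile_steps=5)
short_run = profiled_steps(total_steps=3, profile_steps=5)
```

In this sketch a run with fewer than `1 + profile_steps` steps never reaches the stop condition; whether the script closes the trace elsewhere in that case cannot be told from the diff.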
...@@ -1041,13 +1077,19 @@ def main(): ...@@ -1041,13 +1077,19 @@ def main():
if global_step % args.logging_steps == 0 and jax.process_index() == 0: if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb": if args.report_to == "wandb":
train_metrics = jax_utils.unreplicate(train_metrics)
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
wandb.log( wandb.log(
{ {
"walltime": time.monotonic() - t00,
"train/step": global_step, "train/step": global_step,
"train/epoch": epoch, "train/epoch": global_step / dataset_length,
"train/loss": jax_utils.unreplicate(train_metric)["loss"], "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
**{f"train/{k}": v for k, v in train_metrics.items()},
} }
) )
t0, step0 = time.monotonic(), global_step
train_metrics = []
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained( controlnet.save_pretrained(
f"{args.output_dir}/{global_step}", f"{args.output_dir}/{global_step}",
...@@ -1058,10 +1100,14 @@ def main(): ...@@ -1058,10 +1100,14 @@ def main():
train_step_progress_bar.close() train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
# Create the pipeline using using the trained modules and save it. # Final validation & store model.
if jax.process_index() == 0: if jax.process_index() == 0:
if args.validation_prompt is not None: if args.validation_prompt is not None:
if args.profile_validation:
jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.profile_validation:
jax.profiler.stop_trace()
else: else:
image_logs = None image_logs = None
...@@ -1084,6 +1130,10 @@ def main(): ...@@ -1084,6 +1130,10 @@ def main():
ignore_patterns=["step_*", "epoch_*"], ignore_patterns=["step_*", "epoch_*"],
) )
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
logger.info("Finished training.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
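Note on the logging hunk (`@@ -1041,13 +1077,19 @@`): instead of logging only the last step's loss, the script now averages all metrics collected since the previous log call and reports throughput for that window. A plain-Python sketch of this bookkeeping, assuming scalar metrics and a hypothetical `MetricWindow` helper with an injectable clock (the script uses `jax.tree_util.tree_map` and module-level variables `t00`, `t0` and `step0` instead):

```python
import time
from statistics import fmean


class MetricWindow:
    """Hypothetical helper: collects per-step metrics between two log calls."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._t_start = self._t0 = clock()  # plays the role of t00 and t0
        self._step0 = 0
        self._buffer = []

    def add(self, metrics):
        self._buffer.append(metrics)

    def flush(self, global_step):
        now = self._clock()
        summary = {
            "walltime": now - self._t_start,
            "train/step": global_step,
            # Assumes some time has passed since the previous flush.
            "train/steps_per_sec": (global_step - self._step0) / (now - self._t0),
        }
        keys = self._buffer[0].keys() if self._buffer else []
        for key in keys:
            summary[f"train/{key}"] = fmean([m[key] for m in self._buffer])
        # Start a new window, as the script does after wandb.log.
        self._t0, self._step0, self._buffer = now, global_step, []
        return summary


# Fake clock so the example is deterministic: 0s at start, then 2s and 4s.
ticks = iter([0.0, 2.0, 4.0])
window = MetricWindow(clock=lambda: next(ticks))
for loss in (4.0, 2.0, 3.0, 1.0):
    window.add({"loss": loss})
first = window.flush(global_step=4)
for loss in (1.0, 1.0):
    window.add({"loss": loss})
second = window.flush(global_step=6)
```

With a window of four steps the logged loss is the mean of the four values rather than the last one, which is the smoothing effect the commit message describes.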