Unverified Commit d06e0694 authored by Andreas Steiner, committed by GitHub

Adds profiling flags, computes train metrics average. (#3053)

* WIP controlnet training

- bugfix --streaming
- bugfix running report_to!='wandb'
- adds memory profile before validation

* Adds final logging statement.

* Sets train epochs to 11.

Looking at a longer ~16ep run, we see only good validation images
after ~11ep:

https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8



* Removes --logging_dir (it's not used).

* Adds --profile flags.

* Updates --output_dir=runs/fill-circle-{timestamp}.

* Compute mean of `train_metrics`.

Previously `train_metrics[-1]` was logged, resulting in very bumpy train
metrics.

* Improves logging a bit.

- adds l2_grads gradient norm logging
- adds steps_per_sec
- sets walltime as x coordinate of train/step
- logs controlnet_params config

* Adds --ccache (doesn't really help though).

* minor fix in controlnet flax example (#2986)

* fix the error when using push_to_hub without logging validation

* controlnet_from_pt & controlnet_revision

* add intermediate checkpointing to the guide

* Bugfix --profile_steps

* Sets `TRACKER_PROJECT_NAME='controlnet_fill50k'`.

* Logs fractional epoch.

* Adds relative `walltime` metric.

* Adds `StepTraceAnnotation` and uses `global_step` instead of `step`.

* Applied `black`.

* Streamlines commands in README a bit.

* Removes `--ccache`.

This makes only a very small difference (~1 min) with this model size, so removing
the option introduced in cdb3cc.

* Re-ran `black`.

* Update examples/controlnet/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Converts spaces to tabs.

* Removes repeated args.

* Skips first step (compilation) in profiling

* Updates README with profiling instructions.

* Unifies tabs/spaces in README.

* Re-ran style & quality.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0a73b4d3
......@@ -284,9 +284,9 @@ TPU_TYPE=v4-8
VM_NAME=hg_flax
gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone $ZONE \
--accelerator-type $TPU_TYPE \
--version tpu-vm-v4-base
--zone $ZONE \
--accelerator-type $TPU_TYPE \
--version tpu-vm-v4-base
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
```
......@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
pip install wandb
```
Now let's download two conditioning images that we will use to run validation during training in order to track our progress
```
......@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
export HUB_MODEL_ID="controlnet-fill-circle"
```
And finally start the training
......@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \
--revision="non-ema" \
--from_pt \
--report_to="wandb" \
--max_train_steps=10000 \
--tracker_project_name=$HUB_MODEL_ID \
--num_train_epochs=11 \
--push_to_hub \
--hub_model_id=$HUB_MODEL_ID
```
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
export HUB_MODEL_ID="controlnet-uncanny-faces"
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 50 \
--max_train_steps 5 \
--learning_rate=1e-5 \
--validation_steps=2 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb"
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
--streaming \
--conditioning_image_column=spiga_seg \
--image_column=image \
--caption_column=image_caption \
--resolution=512 \
--max_train_samples 100000 \
--learning_rate=1e-5 \
--train_batch_size=1 \
--revision="flax" \
--report_to="wandb" \
--tracker_project_name=$HUB_MODEL_ID
```
Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
......@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
```bash
--checkpointing_steps=500
--checkpointing_steps=500
```
This will save the trained model in subfolders of your `output_dir`. Subfolder names correspond to the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named `500`.
You can then start your training from this saved checkpoint with
```bash
--controlnet_model_name_or_path="./control_out/500"
--controlnet_model_name_or_path="./control_out/500"
```
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
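For intuition, the Min-SNR weighting clamps the per-timestep signal-to-noise ratio at `snr_gamma` and, for epsilon-prediction, divides by the SNR. A minimal sketch of the formula from the paper (not the script's exact implementation; `snr` is assumed to come from the noise scheduler):
```python
import jax.numpy as jnp

def min_snr_weight(snr, gamma=5.0):
    # Clamp the per-timestep SNR at gamma so nearly noise-free timesteps stop
    # dominating the loss; dividing by snr corresponds to epsilon-prediction.
    return jnp.minimum(snr, gamma) / snr

# Timesteps with very high SNR get strongly down-weighted:
print(min_snr_weight(jnp.array([0.5, 5.0, 50.0])))  # [1.  1.  0.1]
```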
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
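As a rough, illustrative calculation (assuming the usual data-parallel setup where the per-device batch is replicated across devices and gradients are accumulated over several steps; the numbers below are hypothetical):
```python
train_batch_size = 1             # per device, as in the commands above
num_devices = 4                  # e.g. jax.device_count() on a v4-8 TPU VM
gradient_accumulation_steps = 4  # hypothetical value
# Effective batch size seen by each optimizer update:
print(train_batch_size * num_devices * gradient_accumulation_steps)  # 16
```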
You can **profile your code** with:
```bash
--profile_steps=5
```
Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
```bash
pip install tensorflow tensorboard-plugin-profile
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
```
The profile can then be inspected at http://localhost:6006/#profile
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
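For orientation, `--profile_steps` wraps the training loop in a JAX trace roughly like the sketch below (simplified from the training script; the function name and arguments are illustrative):
```python
import jax

def maybe_profile_step(global_step, profile_steps, output_dir, last_loss):
    # Skip step 0 (which includes compilation) and trace the next `profile_steps` steps.
    if profile_steps and global_step == 1:
        last_loss.block_until_ready()         # wait for the compile step to finish
        jax.profiler.start_trace(output_dir)  # trace is readable by TensorBoard's profile plugin
    if profile_steps and global_step == 1 + profile_steps:
        last_loss.block_until_ready()         # wait for the traced steps to finish
        jax.profiler.stop_trace()
```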
......@@ -18,6 +18,7 @@ import logging
import math
import os
import random
import time
from pathlib import Path
import jax
......@@ -220,6 +221,28 @@ def parse_args():
default=None,
help="Revision of controlnet model identifier from huggingface.co/models.",
)
parser.add_argument(
"--profile_steps",
type=int,
default=0,
help="How many training steps to profile in the beginning.",
)
parser.add_argument(
"--profile_validation",
action="store_true",
help="Whether to profile the (last) validation.",
)
parser.add_argument(
"--profile_memory",
action="store_true",
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
)
parser.add_argument(
"--ccache",
type=str,
default=None,
help="Enables compilation cache.",
)
parser.add_argument(
"--controlnet_from_pt",
action="store_true",
......@@ -234,8 +257,9 @@ def parse_args():
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
default="runs/{timestamp}",
help="The output directory where the model predictions and checkpoints will be written. "
"Can contain placeholders: {timestamp}.",
)
parser.add_argument(
"--cache_dir",
......@@ -317,15 +341,6 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--logging_steps",
type=int,
......@@ -459,6 +474,8 @@ def parse_args():
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
......@@ -952,6 +969,11 @@ def main():
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
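# Also log the global L2 norm of the gradients (reported as train/l2_grads).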
def l2(xs):
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
return new_state, metrics, new_train_rng
# Create parallel version of the train step
......@@ -983,32 +1005,38 @@ def main():
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
if jax.process_index() == 0:
if jax.process_index() == 0 and args.report_to == "wandb":
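# Use train/step as the default x-axis, and plot train/step itself against walltime.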
wandb.define_metric("*", step_metric="train/step")
wandb.define_metric("train/step", step_metric="walltime")
wandb.config.update(
{
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
}
)
global_step = 0
global_step = step0 = 0
epochs = tqdm(
range(args.num_train_epochs),
desc="Epoch ... ",
position=0,
disable=jax.process_index() > 0,
)
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
t00 = t0 = time.monotonic()
for epoch in epochs:
# ======================== Training ================================
train_metrics = []
train_metric = None
steps_per_epoch = (
args.max_train_samples // total_train_batch_size
if args.streaming
if args.streaming or args.max_train_samples
else len(train_dataset) // total_train_batch_size
)
train_step_progress_bar = tqdm(
......@@ -1020,10 +1048,18 @@ def main():
)
# train
for batch in train_dataloader:
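# With --profile_steps, skip the first step (compilation) and trace the following profile_steps steps.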
if args.profile_steps and global_step == 1:
train_metric["loss"].block_until_ready()
jax.profiler.start_trace(args.output_dir)
if args.profile_steps and global_step == 1 + args.profile_steps:
train_metric["loss"].block_until_ready()
jax.profiler.stop_trace()
batch = shard(batch)
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
......@@ -1041,13 +1077,19 @@ def main():
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb":
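# Unreplicate and average the metrics collected since the last log, instead of logging only the last step's values.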
train_metrics = jax_utils.unreplicate(train_metrics)
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
wandb.log(
{
"walltime": time.monotonic() - t00,
"train/step": global_step,
"train/epoch": epoch,
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
"train/epoch": global_step / dataset_length,
"train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
**{f"train/{k}": v for k, v in train_metrics.items()},
}
)
t0, step0 = time.monotonic(), global_step
train_metrics = []
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained(
f"{args.output_dir}/{global_step}",
......@@ -1058,10 +1100,14 @@ def main():
train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
# Create the pipeline using using the trained modules and save it.
# Final validation & store model.
if jax.process_index() == 0:
if args.validation_prompt is not None:
if args.profile_validation:
jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if args.profile_validation:
jax.profiler.stop_trace()
else:
image_logs = None
......@@ -1084,6 +1130,10 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.profile_memory:
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
logger.info("Finished training.")
if __name__ == "__main__":
main()