Unverified Commit d4f846fa authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[WIP]Flax training script for controlnet (#2818)



* add train_controlnet_flax

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 58fc8244
...@@ -267,3 +267,99 @@ image = pipe( ...@@ -267,3 +267,99 @@ image = pipe(
image.save("./output.png") image.save("./output.png")
``` ```
## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
### Running on Google Cloud TPU
See below for commands to set up a TPU VM(`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
First create a single TPUv4-8 VM and connect to it:
```
ZONE=us-central2-b
TPU_TYPE=v4-8
VM_NAME=hg_flax
gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone $ZONE \
--accelerator-type $TPU_TYPE \
--version tpu-vm-v4-base
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
```
When connected install JAX `0.4.5`:
```
pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
To verify that JAX was correctly installed, you can run the following command:
```
import jax
jax.device_count()
```
This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.
Then install Diffusers and the library's training dependencies:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -U -r requirements_flax.txt
```
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
```
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
We encourage you to store or share your model with the community. To use huggingface hub, please login to your Hugging Face account, or ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):
```
huggingface-cli login
```
Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
```
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="control_out"
export HUB_MODEL_ID="fill-circle-controlnet"
```
And finally start the training
```
python3 train_controlnet_flax.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--resolution=512 \
--learning_rate=1e-5 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=1000 \
--train_batch_size=2 \
--revision="non-ema" \
--from_pt \
--report_to="wandb" \
--max_train_steps=10000 \
--push_to_hub \
--hub_model_id=$HUB_MODEL_ID
```
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torch.utils.checkpoint
import transformers
from datasets import load_dataset
from flax import jax_utils
from flax.core.frozen_dict import unfreeze
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from diffusers import (
FlaxAutoencoderKL,
FlaxControlNetModel,
FlaxDDPMScheduler,
FlaxStableDiffusionControlNetPipeline,
FlaxUNet2DConditionModel,
)
from diffusers.utils import check_min_version, is_wandb_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")
logger = logging.getLogger(__name__)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
logger.info("Running validation... ")
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
controlnet=controlnet,
safety_checker=None,
dtype=weight_dtype,
revision=args.revision,
from_pt=args.from_pt,
)
params = jax_utils.replicate(params)
params["controlnet"] = controlnet_params
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())
if len(args.validation_image) == len(args.validation_prompt):
validation_images = args.validation_image
validation_prompts = args.validation_prompt
elif len(args.validation_image) == 1:
validation_images = args.validation_image * len(args.validation_prompt)
validation_prompts = args.validation_prompt
elif len(args.validation_prompt) == 1:
validation_images = args.validation_image
validation_prompts = args.validation_prompt * len(args.validation_image)
else:
raise ValueError(
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)
image_logs = []
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
prompts = num_samples * [validation_prompt]
prompt_ids = pipeline.prepare_text_inputs(prompts)
prompt_ids = shard(prompt_ids)
validation_image = Image.open(validation_image)
processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
processed_image = shard(processed_image)
images = pipeline(
prompt_ids=prompt_ids,
image=processed_image,
params=params,
prng_seed=prng_seed,
num_inference_steps=50,
jit=True,
).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_logs.append(
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)
if args.report_to == "wandb":
formatted_images = []
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
wandb.log({"validation": formatted_images})
else:
logger.warn(f"image logging not implemented for {args.report_to}")
return image_logs
def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- controlnet
inference: true
---
"""
model_card = f"""
# controlnet- {repo_name}
These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--from_pt",
action="store_true",
help="Load the pretrained model from a pytorch checkpoint.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--logging_steps",
type=int,
default=100,
help=("log training metric every X steps to `--report_t`"),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing the target image."
)
parser.add_argument(
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--proportion_empty_prompts",
type=float,
default=0,
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
nargs="+",
help=(
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
),
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
nargs="+",
help=(
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
),
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` and logging the images."
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="train_controlnet_flax",
help=("The `project` argument passed to wandb"),
)
parser.add_argument(
"--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over"
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
if args.dataset_name is not None and args.train_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
if args.validation_prompt is not None and args.validation_image is None:
raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
if args.validation_prompt is None and args.validation_image is not None:
raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
if (
args.validation_image is not None
and args.validation_prompt is not None
and len(args.validation_image) != 1
and len(args.validation_prompt) != 1
and len(args.validation_image) != len(args.validation_prompt)
):
raise ValueError(
"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
" or the same number of `--validation_prompt`s and `--validation_image`s"
)
return args
def make_train_dataset(args, tokenizer):
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = column_names[1]
logger.info(f"caption column defaulting to {caption_column}")
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.conditioning_image_column is None:
conditioning_image_column = column_names[2]
logger.info(f"conditioning image column defaulting to {caption_column}")
else:
conditioning_image_column = args.conditioning_image_column
if conditioning_image_column not in column_names:
raise ValueError(
f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if random.random() < args.proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
images = [image_transforms(image) for image in images]
conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
examples["input_ids"] = tokenize_captions(examples)
return examples
if jax.process_index() == 0:
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
return train_dataset
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
batch = {
"pixel_values": pixel_values,
"conditioning_pixel_values": conditioning_pixel_values,
"input_ids": input_ids,
}
batch = {k: v.numpy() for k, v in batch.items()}
return batch
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def get_params_to_save(params):
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
def main():
args = parse_args()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# wandb init
if jax.process_index() == 0 and args.report_to == "wandb":
wandb.init(
project=args.tracker_project_name,
job_type="train",
config=args,
)
if args.seed is not None:
set_seed(args.seed)
rng = jax.random.PRNGKey(0)
# Handle the repository creation
if jax.process_index() == 0:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo_url = create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_url, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
else:
raise NotImplementedError("No tokenizer specified!")
# Get the datasets: you can either provide your own training and evaluation files (see below)
train_dataset = make_train_dataset(args, tokenizer)
total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=total_train_batch_size,
num_workers=args.dataloader_num_workers,
drop_last=True,
)
weight_dtype = jnp.float32
if args.mixed_precision == "fp16":
weight_dtype = jnp.float16
elif args.mixed_precision == "bf16":
weight_dtype = jnp.bfloat16
# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
dtype=weight_dtype,
revision=args.revision,
from_pt=args.from_pt,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
subfolder="vae",
dtype=weight_dtype,
from_pt=args.from_pt,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
dtype=weight_dtype,
revision=args.revision,
from_pt=args.from_pt,
)
if args.controlnet_model_name_or_path:
logger.info("Loading existing controlnet weights")
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32
)
else:
logger.info("Initializing controlnet weights from unet")
rng, rng_params = jax.random.split(rng)
controlnet = FlaxControlNetModel(
in_channels=unet.config.in_channels,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
attention_head_dim=unet.config.attention_head_dim,
cross_attention_dim=unet.config.cross_attention_dim,
use_linear_projection=unet.config.use_linear_projection,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
)
controlnet_params = controlnet.init_weights(rng=rng_params)
controlnet_params = unfreeze(controlnet_params)
for key in [
"conv_in",
"time_embedding",
"down_blocks_0",
"down_blocks_1",
"down_blocks_2",
"down_blocks_3",
"mid_block",
]:
controlnet_params[key] = unet_params[key]
# Optimization
if args.scale_lr:
args.learning_rate = args.learning_rate * total_train_batch_size
constant_scheduler = optax.constant_schedule(args.learning_rate)
adamw = optax.adamw(
learning_rate=constant_scheduler,
b1=args.adam_beta1,
b2=args.adam_beta2,
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
optimizer = optax.chain(
optax.clip_by_global_norm(args.max_grad_norm),
adamw,
)
state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)
noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
# Initialize our training
validation_rng, train_rngs = jax.random.split(rng)
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
# reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
if args.gradient_accumulation_steps > 1:
grad_steps = args.gradient_accumulation_steps
batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)
def compute_loss(params, minibatch, sample_rng):
# Convert images to latent space
vae_outputs = vae.apply(
{"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode
)
latents = vae_outputs.latent_dist.sample(sample_rng)
# (NHWC) -> (NCHW)
latents = jnp.transpose(latents, (0, 3, 1, 2))
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise_rng, timestep_rng = jax.random.split(sample_rng)
noise = jax.random.normal(noise_rng, latents.shape)
# Sample a random timestep for each image
bsz = latents.shape[0]
timesteps = jax.random.randint(
timestep_rng,
(bsz,),
0,
noise_scheduler.config.num_train_timesteps,
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(
minibatch["input_ids"],
params=text_encoder_params,
train=False,
)[0]
controlnet_cond = minibatch["conditioning_pixel_values"]
# Predict the noise residual and compute loss
down_block_res_samples, mid_block_res_sample = controlnet.apply(
{"params": params},
noisy_latents,
timesteps,
encoder_hidden_states,
controlnet_cond,
train=True,
return_dict=False,
)
model_pred = unet.apply(
{"params": unet_params},
noisy_latents,
timesteps,
encoder_hidden_states,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
loss = loss.mean()
return loss
grad_fn = jax.value_and_grad(compute_loss)
# get a minibatch (one gradient accumulation slice)
def get_minibatch(batch, grad_idx):
return jax.tree_util.tree_map(
lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
batch,
)
def loss_and_grad(grad_idx, train_rng):
# create minibatch for the grad step
minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch
sample_rng, train_rng = jax.random.split(train_rng, 2)
loss, grad = grad_fn(state.params, minibatch, sample_rng)
return loss, grad, train_rng
if args.gradient_accumulation_steps == 1:
loss, grad, new_train_rng = loss_and_grad(None, train_rng)
else:
init_loss_grad_rng = (
0.0, # initial value for cumul_loss
jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad
train_rng, # initial value for train_rng
)
def cumul_grad_step(grad_idx, loss_grad_rng):
cumul_loss, cumul_grad, train_rng = loss_grad_rng
loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)
cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))
return cumul_loss, cumul_grad, new_train_rng
loss, grad, new_train_rng = jax.lax.fori_loop(
0,
args.gradient_accumulation_steps,
cumul_grad_step,
init_loss_grad_rng,
)
loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics, new_train_rng
# Create parallel version of the train step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Replicate the train state on each device
state = jax_utils.replicate(state)
unet_params = jax_utils.replicate(unet_params)
text_encoder_params = jax_utils.replicate(text_encoder.params)
vae_params = jax_utils.replicate(vae_params)
# Train!
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Scheduler and math around the number of training steps.
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
if jax.process_index() == 0:
wandb.define_metric("*", step_metric="train/step")
wandb.config.update(
{
"num_train_examples": len(train_dataset),
"total_train_batch_size": total_train_batch_size,
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
"num_devices": jax.device_count(),
}
)
global_step = 0
epochs = tqdm(
range(args.num_train_epochs),
desc="Epoch ... ",
position=0,
disable=jax.process_index() > 0,
)
for epoch in epochs:
# ======================== Training ================================
train_metrics = []
steps_per_epoch = len(train_dataset) // total_train_batch_size
train_step_progress_bar = tqdm(
total=steps_per_epoch,
desc="Training...",
position=1,
leave=False,
disable=jax.process_index() > 0,
)
# train
for batch in train_dataloader:
batch = shard(batch)
state, train_metric, train_rngs = p_train_step(
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
global_step += 1
if global_step >= args.max_train_steps:
break
if (
args.validation_prompt is not None
and global_step % args.validation_steps == 0
and jax.process_index() == 0
):
_ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb":
wandb.log(
{
"train/step": global_step,
"train/epoch": epoch,
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
}
)
train_metric = jax_utils.unreplicate(train_metric)
train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
# Create the pipeline using using the trained modules and save it.
if jax.process_index() == 0:
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
controlnet.save_pretrained(
args.output_dir,
params=get_params_to_save(state.params),
)
if args.push_to_hub:
save_model_card(
repo_name,
image_logs=image_logs,
base_model=args.pretrained_model_name_or_path,
repo_folder=args.output_dir,
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment