Unverified Commit bcd6f3f9 authored by Yasyf Mohamedali's avatar Yasyf Mohamedali Committed by GitHub
Browse files

Various Fixes for Flax Dreambooth (#1782)



* Various Fixes for Flax Dreambooth

- Correctly update the progress bar every epoch
- Allow specifying a pretrained VAE
- Allow specifying a revision to pretrained models
- Cache compiled models between invocations (speeds up TPU execution a lot!)
- Save intermediate checkpoints by specifying `save_steps`

* Don't die when save_steps is not set.

* Address CR

* Address comments

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 19a0ce4a
......@@ -28,6 +28,7 @@ from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, whoami
from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
......@@ -37,6 +38,9 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
logger = logging.getLogger(__name__)
......@@ -49,6 +53,19 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_vae_name_or_path",
type=str,
default=None,
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
......@@ -103,6 +120,7 @@ def parse_args():
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
......@@ -332,7 +350,7 @@ def main():
if cur_class_images < args.num_class_images:
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, safety_checker=None
args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
)
pipeline.set_progress_bar_config(disable=True)
......@@ -383,7 +401,11 @@ def main():
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
else:
raise NotImplementedError("No tokenizer specified!")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
......@@ -437,15 +459,23 @@ def main():
elif args.mixed_precision == "bf16":
weight_dtype = jnp.bfloat16
if args.pretrained_vae_name_or_path:
# TODO(patil-suraj): Upload flax weights for the VAE
vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
else:
vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
vae_arg,
dtype=weight_dtype,
**vae_kwargs,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
)
# Optimization
......@@ -597,6 +627,39 @@ def main():
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
def checkpoint(step=None):
# Create the pipeline using the trained modules and save it.
scheduler, _ = FlaxPNDMScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", from_pt=True
)
pipeline = FlaxStableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
pipeline.save_pretrained(
outdir,
params={
"text_encoder": get_params_to_save(text_encoder_state.params),
"vae": get_params_to_save(vae_params),
"unet": get_params_to_save(unet_state.params),
"safety_checker": safety_checker.params,
},
)
if args.push_to_hub:
message = f"checkpoint-{step}" if step is not None else "End of training"
repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
global_step = 0
epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
......@@ -615,9 +678,11 @@ def main():
)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
train_step_progress_bar.update(jax.local_device_count())
global_step += 1
if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
checkpoint(global_step)
if global_step >= args.max_train_steps:
break
......@@ -626,36 +691,8 @@ def main():
train_step_progress_bar.close()
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
# Create the pipeline using using the trained modules and save it.
if jax.process_index() == 0:
scheduler = FlaxPNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", from_pt=True
)
pipeline = FlaxStableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(
args.output_dir,
params={
"text_encoder": get_params_to_save(text_encoder_state.params),
"vae": get_params_to_save(vae_params),
"unet": get_params_to_save(unet_state.params),
"safety_checker": safety_checker.params,
},
)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
checkpoint()
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment