Unverified Commit 24947317 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples] Add support for Min-SNR weighting strategy for better convergence (#2899)

* improve stable unclip doc.

* feat: support for applying min-snr weighting for faster convergence.

* add: support for validation logging with wandb

* make  not a required arg.

* fix: arg name.

* fix: cli args.

* fix: tracker config.

* fix: loss calculation.

* fix: validation logging.

* fix: unwrap call.

* fix: validation logging.

* fix: internval.

* fix: checkpointing push to hub.

* fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050\#commitcomment-106913193



* fix: norm group test for UNet3D.

* address PR comments.

* remove unneeded code.

* add: entry in the readme and docs.

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

---------
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 8826bae6
...@@ -155,6 +155,28 @@ python train_text_to_image_flax.py \ ...@@ -155,6 +155,28 @@ python train_text_to_image_flax.py \
</jax> </jax>
</frameworkcontent> </frameworkcontent>
## Training with Min-SNR weighting
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
value when using it is 5.0.
You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
* Training without the Min-SNR weighting strategy
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced.
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
<Tip warning={true}>
Training with Min-SNR weighting strategy is only supported in PyTorch.
</Tip>
## LoRA ## LoRA
You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide. You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide.
......
...@@ -111,6 +111,22 @@ image = pipe(prompt="yoda").images[0] ...@@ -111,6 +111,22 @@ image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png") image.save("yoda-pokemon.png")
``` ```
#### Training with Min-SNR weighting
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
value when using it is 5.0.
You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
* Training without the Min-SNR weighting strategy
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced.
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
## Training with LoRA ## Training with LoRA
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
......
...@@ -41,15 +41,74 @@ import diffusers ...@@ -41,15 +41,74 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0") check_min_version("0.15.0.dev0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
images.append(image)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
for i, image in enumerate(images)
]
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
...@@ -111,6 +170,13 @@ def parse_args(): ...@@ -111,6 +170,13 @@ def parse_args():
"value if set." "value if set."
), ),
) )
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument( parser.add_argument(
"--output_dir", "--output_dir",
type=str, type=str,
...@@ -192,6 +258,13 @@ def parse_args(): ...@@ -192,6 +258,13 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
) )
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument( parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
) )
...@@ -297,6 +370,21 @@ def parse_args(): ...@@ -297,6 +370,21 @@ def parse_args():
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -314,11 +402,6 @@ def parse_args(): ...@@ -314,11 +402,6 @@ def parse_args():
return args return args
dataset_name_mapping = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def main(): def main():
args = parse_args() args = parse_args()
...@@ -410,6 +493,30 @@ def main(): ...@@ -410,6 +493,30 @@ def main():
else: else:
raise ValueError("xformers is not available. Make sure it is installed correctly") raise ValueError("xformers is not available. Make sure it is installed correctly")
def compute_snr(timesteps):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR.
snr = (alpha / sigma) ** 2
return snr
# `accelerate` 0.16.0 will have better support for customized saving # `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
...@@ -507,7 +614,7 @@ def main(): ...@@ -507,7 +614,7 @@ def main():
column_names = dataset["train"].column_names column_names = dataset["train"].column_names
# 6. Get the column names for input/target. # 6. Get the column names for input/target.
dataset_columns = dataset_name_mapping.get(args.dataset_name, None) dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None: if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else: else:
...@@ -626,7 +733,9 @@ def main(): ...@@ -626,7 +733,9 @@ def main():
# We need to initialize the trackers we use, and also store our configuration. # We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process. # The trackers initializes automatically on the main process.
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args)) tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)
# Train! # Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...@@ -715,7 +824,23 @@ def main(): ...@@ -715,7 +824,23 @@ def main():
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
# Gather the losses across all processes for logging (if we use distributed training). # Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
...@@ -750,6 +875,26 @@ def main(): ...@@ -750,6 +875,26 @@ def main():
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
log_validation(
vae,
text_encoder,
tokenizer,
unet,
args,
accelerator,
weight_dtype,
global_step,
)
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
# Create the pipeline using the trained modules and save it. # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment