Unverified Commit 9c13f865 authored by Sayak Paul, committed by GitHub

[training] add an offload utility that can be used as a context manager. (#11775)



* add an offload utility that can be used as a context manager.

* update

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 5c520972
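In short: the training scripts previously toggled devices by hand around each encoding step (`.to(accelerator.device)` before, `.to("cpu")` after, each guarded by `if args.offload:`). The new `offload_models` helper in `diffusers.training_utils` wraps that round-trip in a context manager. A minimal usage sketch outside the training scripts (the checkpoint id and prompt below are illustrative only, not part of this change):

from diffusers import DiffusionPipeline
from diffusers.training_utils import offload_models

# hypothetical checkpoint, used only for illustration
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")

with offload_models(pipe, device="cuda", offload=True):
    # the pipeline lives on the GPU only inside this block
    image = pipe("a photo of a dog").images[0]
# on exit, the pipeline is moved back to its original device (CPU here);
# with offload=False the context manager is a no-op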
@@ -13,6 +13,7 @@ on:
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
- "examples/**/*.py"
workflow_dispatch:
concurrency:
......
@@ -58,6 +58,7 @@ from diffusers.training_utils import (
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
offload_models,
)
from diffusers.utils import (
check_min_version,
@@ -1364,43 +1365,34 @@ def main(args):
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
    instance_prompt_hidden_states_t5,
    instance_prompt_hidden_states_llama3,
    instance_pooled_prompt_embeds,
    _,
    _,
    _,
) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
    (
        instance_prompt_hidden_states_t5,
        instance_prompt_hidden_states_llama3,
        instance_pooled_prompt_embeds,
        _,
        _,
        _,
    ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
    compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
)
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
    (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
        compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
    )
validation_embeddings = {}
if args.validation_prompt is not None:
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
(
    validation_embeddings["prompt_embeds_t5"],
    validation_embeddings["prompt_embeds_llama3"],
    validation_embeddings["pooled_prompt_embeds"],
    validation_embeddings["negative_prompt_embeds_t5"],
    validation_embeddings["negative_prompt_embeds_llama3"],
    validation_embeddings["negative_pooled_prompt_embeds"],
) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
    (
        validation_embeddings["prompt_embeds_t5"],
        validation_embeddings["prompt_embeds_llama3"],
        validation_embeddings["pooled_prompt_embeds"],
        validation_embeddings["negative_prompt_embeds_t5"],
        validation_embeddings["negative_prompt_embeds_llama3"],
        validation_embeddings["negative_pooled_prompt_embeds"],
    ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
if args.offload:
    vae = vae.to(accelerator.device)
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
with offload_models(vae, device=accelerator.device, offload=args.offload):
    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
if args.offload:
    vae = vae.to("cpu")
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
......
@@ -5,12 +5,14 @@ import math
import random
import re
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
from .utils import (
convert_state_dict_to_diffusers,
@@ -318,6 +320,39 @@ def free_memory():
torch.xpu.empty_cache()
@contextmanager
def offload_models(
    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
):
    """
    Context manager that, if `offload=True`, moves each module to `device` on enter, then moves it back to its
    original device on exit.

    Args:
        modules (`torch.nn.Module` or `DiffusionPipeline`): The model(s) or pipeline to move.
        device (`str` or `torch.device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading.
    """
    if offload:
        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
        # record where each module was
        if is_model:
            original_devices = [next(m.parameters()).device for m in modules]
        else:
            # only a single pipeline is supported; keep a list so the restore loop below works in both branches
            assert len(modules) == 1
            original_devices = [modules[0].device]
        # move to target device
        for m in modules:
            m.to(device)

    try:
        yield
    finally:
        if offload:
            # move back to original devices
            for m, orig_dev in zip(modules, original_devices):
                m.to(orig_dev)
def parse_buckets_string(buckets_str):
"""Parses a string defining buckets into a list of (height, width) tuples."""
if not buckets_str:
......
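The restore semantics of the new helper are easy to exercise outside the training scripts. A small sketch (assumes a CUDA device is available; the toy module is illustrative):

import torch

from diffusers.training_utils import offload_models

linear = torch.nn.Linear(4, 4)           # starts on CPU
print(next(linear.parameters()).device)  # cpu

with offload_models(linear, device="cuda", offload=True):
    print(next(linear.parameters()).device)  # cuda:0 inside the block

print(next(linear.parameters()).device)  # moved back to cpu on exit

# with offload=False the context manager does nothing and the module never moves
with offload_models(linear, device="cuda", offload=False):
    print(next(linear.parameters()).device)  # still cpu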