Unverified commit 9c13f865 authored by Sayak Paul, committed by GitHub

[training] add an offload utility that can be used as a context manager. (#11775)



* add an offload utility that can be used as a context manager.

* update

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 5c520972
@@ -13,6 +13,7 @@ on:
       - "src/diffusers/loaders/peft.py"
       - "tests/pipelines/test_pipelines_common.py"
       - "tests/models/test_modeling_common.py"
+      - "examples/**/*.py"
   workflow_dispatch:

 concurrency:
......
@@ -58,6 +58,7 @@ from diffusers.training_utils import (
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1364,43 +1365,34 @@ def main(args):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )

     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def main(args):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
-                    model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                        model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
......
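
The training-script changes above all follow the same shape: the guarded `.to(accelerator.device)` / `.to("cpu")` bookkeeping around each encode call is replaced by a single `with offload_models(...)` block. Below is a standalone sketch of that before/after pattern; a tiny `torch.nn.Linear` stands in for the text encoding pipeline, and plain variables stand in for `args.offload` and `accelerator.device` (these stand-ins are not names from the training script).

```python
import torch
from diffusers.training_utils import offload_models  # added by this commit

offload = True                        # stands in for args.offload
device = "cpu"                        # stands in for accelerator.device (CPU so the sketch runs anywhere)
text_encoder = torch.nn.Linear(4, 4)  # stands in for text_encoding_pipeline

# Old pattern, removed by this commit: manual bookkeeping around every call.
if offload:
    text_encoder = text_encoder.to(device)
embeds = text_encoder(torch.randn(1, 4))
if offload:
    text_encoder = text_encoder.to("cpu")

# New pattern: the context manager owns the move-to-device / move-back logic.
with offload_models(text_encoder, device=device, offload=offload):
    embeds = text_encoder(torch.randn(1, 4))
```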
@@ -5,12 +5,14 @@ import math
 import random
 import re
 import warnings
+from contextlib import contextmanager
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch

 from .models import UNet2DConditionModel
+from .pipelines import DiffusionPipeline
 from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
@@ -318,6 +320,39 @@ def free_memory():
         torch.xpu.empty_cache()


+@contextmanager
+def offload_models(
+    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
+):
+    """
+    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
+    device on exit.
+
+    Args:
+        device (`str` or `torch.device`): Device to move the `modules` to.
+        offload (`bool`): Flag to enable offloading.
+    """
+    if offload:
+        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
+        # record where each module was
+        if is_model:
+            original_devices = [next(m.parameters()).device for m in modules]
+        else:
+            assert len(modules) == 1
+            original_devices = [modules[0].device]
+        # move to target device
+        for m in modules:
+            m.to(device)
+
+    try:
+        yield
+    finally:
+        if offload:
+            # move back to original devices
+            for m, orig_dev in zip(modules, original_devices):
+                m.to(orig_dev)
+
+
 def parse_buckets_string(buckets_str):
     """Parses a string defining buckets into a list of (height, width) tuples."""
     if not buckets_str:
......
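
For reference, a minimal usage sketch of the new helper (it assumes a diffusers build that includes `offload_models`; the `nn.Linear` module is a placeholder, not anything from this diff). Note that when a `DiffusionPipeline` is passed instead of plain modules, the helper expects exactly one positional argument (`assert len(modules) == 1`).

```python
import torch
from diffusers.training_utils import offload_models

model = torch.nn.Linear(8, 8)  # placeholder module; starts on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

with offload_models(model, device=device, offload=True):
    # inside the block the module lives on the target device
    out = model(torch.randn(1, 8, device=device))

# on exit the module has been moved back to where it started (CPU here)
print(next(model.parameters()).device)

# with offload=False the context manager is a no-op, so callers can keep one
# code path and let a flag such as --offload decide at runtime
with offload_models(model, device=device, offload=False):
    out = model(torch.randn(1, 8))
```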