"benchmark/vscode:/vscode.git/clone" did not exist on "a049864270aa44226ccc33f3b5ef929222307e9c"
Unverified commit 76ec3d1f, authored by Aryan and committed by GitHub

Support dynamically loading/unloading loras with group offloading (#11804)

* update

* add test

* address review comments

* update

* fixes

* change decorator order to fix tests

* try fix

* fight tests
parent cdaf84a7
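
As a rough usage sketch of what this change enables (not part of the commit itself; the pipeline class, checkpoint id, and LoRA path below are placeholders), a LoRA can now be loaded and unloaded while group offloading stays enabled on the denoiser:

import torch
from diffusers import CogVideoXPipeline  # placeholder: any pipeline whose denoiser supports group offloading

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16)

# Group offloading keeps the denoiser weights on the offload device and onloads them
# group by group during the forward pass.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# With this commit, the group offloading hooks are removed and reapplied around LoRA
# loading/unloading instead of erroring out.
pipe.load_lora_weights("path/to/lora.safetensors")  # placeholder path
pipe.unload_lora_weights()
pipe.load_lora_weights("path/to/lora.safetensors")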
@@ -25,6 +25,7 @@ import torch.nn as nn
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 
+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
             adapter_name = get_adapter_name(text_encoder)
 
         # <Unsafe code
-        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+            _pipeline
+        )
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />
 
     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False
 
     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
                 isinstance(component._hf_hook, AlignDevicesHook)
                 or hasattr(component._hf_hook, "hooks")
                 and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
             )
 
+        if is_sequential_cpu_offload or is_model_cpu_offload:
             logger.info(
                 "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
             )
-            if is_sequential_cpu_offload or is_model_cpu_offload:
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
                 remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
 
 class LoraBaseMixin:
...
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
 import safetensors
 import torch
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ class PeftAdapterMixin:
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error.
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
 
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ class PeftAdapterMixin:
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />
 
         if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ class PeftAdapterMixin:
         if hasattr(self, "peft_config"):
             del self.peft_config
 
+        _maybe_remove_and_reapply_group_offloading(self)
+
     def disable_lora(self):
         """
         Disables the active LoRA layers of the underlying model.
...
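The PeftAdapterMixin hunks above cover the model-level entry points as well. A minimal sketch under the same caveats (the model class, checkpoint id, and LoRA path are placeholders, and the prefix must match how the state dict was saved):

import torch
from safetensors.torch import load_file
from diffusers import CogVideoXTransformer3DModel  # placeholder: any ModelMixin with PEFT support

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# load_lora_adapter() runs through _optionally_disable_offloading() shown above and,
# once the adapter is injected, reapplies the group offloading hooks.
lora_state_dict = load_file("path/to/pytorch_lora_weights.safetensors")  # placeholder path
transformer.load_lora_adapter(lora_state_dict, prefix="transformer")

# Per the last hunk, removing the adapter now also ends with
# _maybe_remove_and_reapply_group_offloading(self), so the hooks survive the unload too.
transformer.unload_lora()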
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..models.embeddings import (
     ImageProjection,
     IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
 
         if is_lora:
             deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                 state_dict=state_dict,
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
         # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
         if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )
 
         # only custom diffusion needs to set attn processors
         self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />
 
     def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
 
         if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
 
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin:
             if warn_msg:
                 logger.warning(warn_msg)
 
-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
 
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
...
@@ -16,6 +16,7 @@ import sys
 import unittest
 
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -28,6 +29,7 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
+    require_torch_accelerator,
 )
@@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
 
+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
...
@@ -18,10 +18,17 @@ import unittest
 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel
 
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    require_torch_accelerator,
+    skip_mps,
+    torch_device,
+)
 
 
 sys.path.append(".")
@@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
             "Loading from saved checkpoints should give same results.",
         )
 
+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
...
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
     is_torch_version,
     require_peft_backend,
     require_peft_version_greater,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2355,3 +2356,73 @@ class PeftLoraLoaderMixinTests:
             pipe.load_lora_weights(tmpdirname)
             output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+        onload_device = torch_device
+        offload_device = torch.device("cpu")
+
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            # Test group offloading with load_lora_weights
+            denoiser.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type=offload_type,
+                num_blocks_per_group=1,
+                use_stream=use_stream,
+            )
+            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_1 is not None)
+            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            # Test group offloading after removing the lora
+            pipe.unload_lora_weights()
+            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_2 is not None)
+            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+
+            # Add the lora again and check if group offloading works
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_3 is not None)
+            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)