Unverified Commit 31058cda authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] allow loras to be loaded with low_cpu_mem_usage. (#9510)

* allow loras to be loaded with low_cpu_mem_usage.

* add flux support but note https://github.com/huggingface/diffusers/pull/9510\#issuecomment-2378316687



* low_cpu_mem_usage.

* fix-copies

* fix-copies again

* tests

* _LOW_CPU_MEM_USAGE_DEFAULT_LORA

* _peft_version default.

* version checks.

* version check.

* version check.

* version check.

* require peft 0.13.1.

* explicitly specify low_cpu_mem_usage=False.

* docs.

* transformers version 4.45.2.

* update

* fix

* empty

* better name initialize_dummy_state_dict.

* doc todos.

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* fix-copies

---------
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent ec9e5264
......@@ -75,6 +75,12 @@ image
![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
<Tip>
By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
</Tip>
## Merge adapters
You can also merge different adapter checkpoints for inference to blend their styles together.
......
This diff is collapsed.
......@@ -115,6 +115,9 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Example:
......@@ -142,8 +145,14 @@ class UNet2DConditionLoadersMixin:
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
......@@ -209,6 +218,7 @@ class UNet2DConditionLoadersMixin:
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(
......@@ -268,7 +278,9 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
def _process_lora(
self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
......@@ -335,9 +347,12 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if incompatible_keys is not None:
# check only for unexpected keys
......
......@@ -388,6 +388,24 @@ def require_peft_version_greater(peft_version):
return decorator
def require_transformers_version_greater(transformers_version):
"""
Decorator marking a test that requires transformers with a specific version, this would require some specific
versions of PEFT and transformers.
"""
def decorator(test_case):
correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version)
return unittest.skipUnless(
correct_transformers_version,
f"test requires transformers with the version greater than {transformers_version}",
)(test_case)
return decorator
def require_accelerate_version_greater(accelerate_version):
def decorator(test_case):
correct_accelerate_version = is_peft_available() and version.parse(
......
......@@ -32,13 +32,14 @@ from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_peft_version_greater,
require_transformers_version_greater,
skip_mps,
torch_device,
)
if is_peft_available():
from peft import LoraConfig
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_peft_model_state_dict
......@@ -65,6 +66,12 @@ def check_if_lora_correctly_set(model) -> bool:
return False
def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
......@@ -272,6 +279,136 @@ class PeftLoraLoaderMixinTests:
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
@require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"The LoRA params should be on 'meta' device.",
)
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
"No param should be on 'meta' device.",
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
_, _, inputs = self.get_dummy_inputs()
output_lora = pipe(**inputs)[0]
self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.1")
def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
# Now, check for `low_cpu_mem_usage.`
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
)
def test_simple_inference_with_text_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + scale argument
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment