Unverified Commit 31058cda authored by Sayak Paul, committed by GitHub

[LoRA] allow loras to be loaded with low_cpu_mem_usage. (#9510)

* allow loras to be loaded with low_cpu_mem_usage.

* add flux support but note https://github.com/huggingface/diffusers/pull/9510#issuecomment-2378316687



* low_cpu_mem_usage.

* fix-copies

* fix-copies again

* tests

* _LOW_CPU_MEM_USAGE_DEFAULT_LORA

* _peft_version default.

* version checks.

* version check.

* version check.

* version check.

* require peft 0.13.1.

* explicitly specify low_cpu_mem_usage=False.

* docs.

* transformers version 4.45.2.

* update

* fix

* empty

* better name initialize_dummy_state_dict.

* doc todos.

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* fix-copies

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent ec9e5264
@@ -75,6 +75,12 @@ image
![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
<Tip>
By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
</Tip>
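As a minimal sketch of passing the flag explicitly (the LoRA repo id below is a placeholder, not from this guide):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Explicitly request the fast loading path; with up-to-date PEFT and
# Transformers this is already the default, and older PEFT versions raise
# an error when the flag is set.
pipeline.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=True)
```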
## Merge adapters
You can also merge different adapter checkpoints for inference to blend their styles together.
...
@@ -115,6 +115,9 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
Example:
@@ -142,8 +145,14 @@ class UNet2DConditionLoadersMixin:
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
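# `low_cpu_mem_usage` relies on meta-device loading support added in PEFT 0.13.1,
# so reject older versions up front.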
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
@@ -209,6 +218,7 @@ class UNet2DConditionLoadersMixin:
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(
@@ -268,7 +278,9 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(
self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -335,9 +347,12 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
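# Only forward `low_cpu_mem_usage` when the installed PEFT version accepts the kwarg.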
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if incompatible_keys is not None:
# check only for unexpected keys
...
@@ -388,6 +388,24 @@ def require_peft_version_greater(peft_version):
return decorator
def require_transformers_version_greater(transformers_version):
"""
Decorator marking a test that requires transformers with a specific version, this would require some specific
versions of PEFT and transformers.
"""
def decorator(test_case):
correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version)
return unittest.skipUnless(
correct_transformers_version,
f"test requires transformers with the version greater than {transformers_version}",
)(test_case)
return decorator
def require_accelerate_version_greater(accelerate_version):
def decorator(test_case):
correct_accelerate_version = is_accelerate_available() and version.parse(
...
@@ -32,13 +32,14 @@ from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_peft_version_greater,
require_transformers_version_greater,
skip_mps,
torch_device,
)
if is_peft_available():
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_peft_model_state_dict
@@ -65,6 +66,12 @@ def check_if_lora_correctly_set(model) -> bool:
return False
def initialize_dummy_state_dict(state_dict):
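# Every value must be a meta-device placeholder (as produced by low_cpu_mem_usage-style
# injection); replace each with a random tensor of the same shape and dtype on the test device.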
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
@@ -272,6 +279,136 @@ class PeftLoraLoaderMixinTests:
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
@require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"The LoRA params should be on 'meta' device.",
)
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
"No param should be on 'meta' device.",
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
_, _, inputs = self.get_dummy_inputs()
output_lora = pipe(**inputs)[0]
self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.1")
def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
# Now, check for `low_cpu_mem_usage`.
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
)
def test_simple_inference_with_text_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + scale argument
...