Unverified Commit 668e34c6 authored by Sayak Paul, committed by GitHub

[LoRA SD3] add support for lora fusion in sd3 (#8616)

* add support for lora fusion in sd3

* add test to ensure fused lora and effective lora produce same outputs
parent 25d7bb3e
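
For orientation, here is a minimal usage sketch of the pipeline-level API this commit adds. The checkpoint and LoRA repo come from the docstring example further down in the diff; the prompt and step count are illustrative assumptions, not part of the change.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "nerijs/pixel-art-medium-128-v0.1",
    weight_name="pixel-art-medium-128-v0.1.safetensors",
    adapter_name="pixel",
)

# Merge the LoRA weights into the transformer so inference runs without the
# extra LoRA matmuls; `lora_scale` controls how strongly the adapter is applied.
pipeline.fuse_lora(lora_scale=0.7)
image = pipeline("pixel art of a corgi", num_inference_steps=28).images[0]  # illustrative prompt

# Restore the original transformer weights; the LoRA layers stay loaded.
pipeline.unfuse_lora()
```
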
@@ -1728,3 +1728,78 @@ class SD3LoraLoaderMixin:
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    def fuse_lora(
        self,
        fuse_transformer: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
            lora_scale (`float`, defaults to 1.0):
                Controls how much the LoRA parameters influence the outputs.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check the fused weights for NaN values before fusing and, if NaN values are found, skip
                fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "nerijs/pixel-art-medium-128-v0.1",
            weight_name="pixel-art-medium-128-v0.1.safetensors",
            adapter_name="pixel",
        )
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        if fuse_transformer:
            self.num_fused_loras += 1
            transformer = (
                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
            )
            transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

    def unfuse_lora(self, unfuse_transformer: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
        if unfuse_transformer:
            for module in transformer.modules():
                if isinstance(module, BaseTunerLayer):
                    module.unmerge()

        self.num_fused_loras -= 1
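
Continuing the usage sketch near the top of this diff, the optional arguments of the two methods above can be exercised like this; `safe_fusing=True` checks the merged weights for NaNs before committing them, and `adapter_names` restricts fusion to the named adapters (here the `"pixel"` adapter loaded earlier):

```py
# Fuse only the "pixel" adapter and validate the merged weights for NaNs first.
pipeline.fuse_lora(lora_scale=0.7, safe_fusing=True, adapter_names=["pixel"])

# ... run inference with the fused weights ...

# Undo the merge; `num_fused_loras` is decremented and the LoRA layers remain attached.
pipeline.unfuse_lora()
```
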
@@ -13,6 +13,8 @@
# limitations under the License.
import inspect
from functools import partial
from typing import Any, Dict, List, Optional, Union

import torch
@@ -239,6 +241,47 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
...
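
The model-level methods above rely on `Module.apply` with `functools.partial` to visit every PEFT `BaseTunerLayer`, plus an `inspect.signature` check so that `adapter_names` is only forwarded when the installed PEFT version supports it. Below is a minimal stand-alone sketch of that signature check; the helper and the two `merge` stand-ins are hypothetical, for illustration only.

```py
import inspect


def call_merge(merge_fn, safe_merge=False, adapter_names=None):
    # Only forward `adapter_names` if the callee accepts it; otherwise fail
    # loudly when the caller explicitly asked for it (mirrors the BC check above).
    kwargs = {"safe_merge": safe_merge}
    supported = list(inspect.signature(merge_fn).parameters)
    if "adapter_names" in supported:
        kwargs["adapter_names"] = adapter_names
    elif adapter_names is not None:
        raise ValueError("`adapter_names` is not supported by this `merge` implementation.")
    return merge_fn(**kwargs)


def old_merge(safe_merge=False):  # hypothetical stand-in for an older PEFT layer
    return "merged all active adapters"


def new_merge(safe_merge=False, adapter_names=None):  # hypothetical stand-in for a newer PEFT layer
    return f"merged adapters: {adapter_names}"


print(call_merge(old_merge))                           # works, kwarg silently dropped
print(call_merge(new_merge, adapter_names=["pixel"]))  # kwarg forwarded
# call_merge(old_merge, adapter_names=["pixel"])       # would raise ValueError
```
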
@@ -205,3 +205,85 @@ class SD3LoRATests(unittest.TestCase):
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
        self.assertTrue(
            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
            "Fused lora should produce the same output as an unfused but effective LoRA",
        )

    def test_simple_inference_with_transformer_fuse_unfuse(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

        pipe.unfuse_lora()
        # Unfusing should keep the LoRA layers attached
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_unfused_lora = pipe(**inputs).images
        self.assertTrue(
            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
            "Fused and unfused LoRA should produce the same output",
        )
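
The tests above compare pipeline outputs; the same fuse/unfuse round-trip property can also be checked directly at the weight level with plain PEFT on a toy module. This is a sketch under the assumption that a recent `peft` release is installed; it uses `inject_adapter_in_model` and the per-layer `merge`/`unmerge` calls that `fuse_lora`/`unfuse_lora` dispatch to, not the SD3 transformer itself.

```py
import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer

# A toy model with a single Linear layer (named "0") targeted by the adapter.
model = torch.nn.Sequential(torch.nn.Linear(16, 16))
config = LoraConfig(r=4, lora_alpha=4, target_modules=["0"], init_lora_weights=False)
model = inject_adapter_in_model(config, model)

layer = model[0]  # now a PEFT LoRA layer wrapping the original Linear
assert isinstance(layer, BaseTunerLayer)
original = layer.base_layer.weight.detach().clone()

layer.merge()    # the per-layer operation behind fuse_lora()
merged = layer.base_layer.weight.detach().clone()
layer.unmerge()  # the per-layer operation behind unfuse_lora()
restored = layer.base_layer.weight.detach().clone()

print(torch.allclose(original, merged))               # False: the LoRA delta was folded in
print(torch.allclose(original, restored, atol=1e-6))  # True: unmerge restores the base weights
```
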