"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b65928b5561a1b4cfdaf82482f41fbf8b7a8b9bb"
Unverified commit c583f3b4, authored by Patrick von Platen, committed by GitHub
Browse files

Fuse loras (#4473)



* Fuse loras

* initial implementation.

* add slow test one.

* styling

* add: test for checking efficiency

* print

* position

* place model offload correctly

* style

* style.

* unfuse test.

* final checks

* remove warning test

* remove warnings altogether

* debugging

* tighten up tests.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* suit up the generator initialization a bit.

* remove print

* update assertion.

* debugging

* remove print.

* fix: assertions.

* style

* can generator be a problem?

* generator

* correct tests.

* support text encoder lora fusion.

* tighten up tests.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 12358b98
...@@ -85,7 +85,49 @@ class PatchedLoraProjection(nn.Module): ...@@ -85,7 +85,49 @@ class PatchedLoraProjection(nn.Module):
self.lora_scale = lora_scale self.lora_scale = lora_scale
def _fuse_lora(self):
    """Fold the LoRA delta into the wrapped linear layer's weight, in place.

    Does nothing when `self.lora_linear_layer` is None (no LoRA attached, or
    it was already fused). Otherwise the product of the LoRA up and down
    weights is added to `self.regular_linear_layer.weight`, the LoRA layer is
    dropped, and the matrices that were added are kept on the CPU as
    `self.w_up` / `self.w_down` so that `_unfuse_lora` can subtract them.

    The delta is scaled by `network_alpha / rank` when `network_alpha` is set
    and by `self.lora_scale`, because `forward` multiplies the LoRA branch by
    `self.lora_scale`; without it the fused layer would only match the
    unfused one when the scale is 1.
    """
    lora = self.lora_linear_layer
    if lora is None:
        return

    base = self.regular_linear_layer
    dtype, device = base.weight.data.dtype, base.weight.data.device

    logger.info(f"Fusing LoRA weights for {self.__class__}")

    # Accumulate in float32 and cast back at the end, so a half-precision
    # base weight does not lose accuracy during the addition.
    w_orig = base.weight.data.float()
    w_up = lora.up.weight.data.float()
    w_down = lora.down.weight.data.float()

    if lora.network_alpha is not None:
        w_up = w_up * lora.network_alpha / lora.rank

    # Keep the fused weight equivalent to `forward`, which computes
    # regular(x) + lora_scale * lora(x). Folding the scale into `w_up` also
    # keeps the stored matrices equal to exactly what was added.
    w_up = w_up * self.lora_scale

    fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
    base.weight.data = fused_weight.to(device=device, dtype=dtype)

    # we can drop the lora layer now
    self.lora_linear_layer = None

    # offload the up and down matrices to CPU to not blow the memory
    self.w_up = w_up.cpu()
    self.w_down = w_down.cpu()
def _unfuse_lora(self):
    """Subtract a previously fused LoRA delta from the wrapped linear weight.

    Does nothing unless both `self.w_up` and `self.w_down` hold tensors. The
    attributes are reset to None after unfusing, so the guard tests for None
    as well as for a missing attribute; a plain `hasattr` check lets a second
    call through and fails on `None.to(...)`.
    """
    w_up = getattr(self, "w_up", None)
    w_down = getattr(self, "w_down", None)
    if w_up is None or w_down is None:
        return

    logger.info(f"Unfusing LoRA weights for {self.__class__}")

    fused_weight = self.regular_linear_layer.weight.data
    dtype, device = fused_weight.dtype, fused_weight.device

    # Mirror `_fuse_lora`: do the arithmetic in float32 and cast back, so
    # the subtraction is not carried out in a reduced-precision dtype.
    w_up = w_up.to(device=device).float()
    w_down = w_down.to(device=device).float()

    unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
    self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

    # Release the stored matrices; this also marks the layer as not fused.
    self.w_up = None
    self.w_down = None
def forward(self, input):
    """Apply the wrapped linear layer, plus the scaled LoRA branch if present.

    When `self.lora_linear_layer` is None only the regular layer is applied;
    otherwise `self.lora_scale` times the LoRA layer's output is added.
    """
    lora = self.lora_linear_layer
    if lora is None:
        return self.regular_linear_layer(input)

    base_output = self.regular_linear_layer(input)
    return base_output + self.lora_scale * lora(input)
...@@ -525,6 +567,20 @@ class UNet2DConditionLoadersMixin: ...@@ -525,6 +567,20 @@ class UNet2DConditionLoadersMixin:
save_function(state_dict, os.path.join(save_directory, weight_name)) save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def fuse_lora(self):
    """Fuse LoRA weights throughout this model.

    Hands `self._fuse_lora_apply` to `self.apply`, which is expected to call
    it on the modules of this model.
    """
    visitor = self._fuse_lora_apply
    self.apply(visitor)
def _fuse_lora_apply(self, module):
    """Call `module._fuse_lora()` if `module` has such an attribute."""
    if not hasattr(module, "_fuse_lora"):
        return
    module._fuse_lora()
def unfuse_lora(self):
    """Undo LoRA fusion throughout this model.

    Hands `self._unfuse_lora_apply` to `self.apply`, which is expected to
    call it on the modules of this model.
    """
    visitor = self._unfuse_lora_apply
    self.apply(visitor)
def _unfuse_lora_apply(self, module):
    """Call `module._unfuse_lora()` if `module` has such an attribute."""
    if not hasattr(module, "_unfuse_lora"):
        return
    module._unfuse_lora()
class TextualInversionLoaderMixin: class TextualInversionLoaderMixin:
r""" r"""
...@@ -1712,6 +1768,83 @@ class LoraLoaderMixin: ...@@ -1712,6 +1768,83 @@ class LoraLoaderMixin:
# Safe to call the following regardless of LoRA. # Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch() self._remove_text_encoder_monkey_patch()
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
        fuse_text_encoder (`bool`, defaults to `True`):
            Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
            LoRA parameters then it won't have any effect. A `text_encoder` / `text_encoder_2` attribute that is
            missing or `None` is skipped.
    """
    if fuse_unet:
        self.unet.fuse_lora()

    def fuse_text_encoder_lora(text_encoder):
        # Only the first projection of each group is type-checked; the
        # remaining ones are assumed to be patched alongside it.
        for _, attn_module in text_encoder_attn_modules(text_encoder):
            if isinstance(attn_module.q_proj, PatchedLoraProjection):
                for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
                    getattr(attn_module, proj_name)._fuse_lora()

        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                for fc_name in ("fc1", "fc2"):
                    getattr(mlp_module, fc_name)._fuse_lora()

    if fuse_text_encoder:
        for encoder_attr in ("text_encoder", "text_encoder_2"):
            # `hasattr` alone would let an attribute that exists but is None
            # through to the module walkers, which cannot handle it.
            text_encoder = getattr(self, encoder_attr, None)
            if text_encoder is not None:
                fuse_text_encoder_lora(text_encoder)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
        unfuse_text_encoder (`bool`, defaults to `True`):
            Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
            LoRA parameters then it won't have any effect. A `text_encoder` / `text_encoder_2` attribute that is
            missing or `None` is skipped.
    """
    if unfuse_unet:
        self.unet.unfuse_lora()

    def unfuse_text_encoder_lora(text_encoder):
        # Only the first projection of each group is type-checked; the
        # remaining ones are assumed to be patched alongside it.
        for _, attn_module in text_encoder_attn_modules(text_encoder):
            if isinstance(attn_module.q_proj, PatchedLoraProjection):
                for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
                    getattr(attn_module, proj_name)._unfuse_lora()

        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                for fc_name in ("fc1", "fc2"):
                    getattr(mlp_module, fc_name)._unfuse_lora()

    if unfuse_text_encoder:
        for encoder_attr in ("text_encoder", "text_encoder_2"):
            # `hasattr` alone would let an attribute that exists but is None
            # through to the module walkers, which cannot handle it.
            text_encoder = getattr(self, encoder_attr, None)
            if text_encoder is not None:
                unfuse_text_encoder_lora(text_encoder)
class FromSingleFileMixin: class FromSingleFileMixin:
""" """
......
...@@ -14,9 +14,15 @@ ...@@ -14,9 +14,15 @@
from typing import Optional from typing import Optional
import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LoRALinearLayer(nn.Module): class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
...@@ -91,6 +97,51 @@ class LoRACompatibleConv(nn.Conv2d): ...@@ -91,6 +97,51 @@ class LoRACompatibleConv(nn.Conv2d):
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
    """Store `lora_layer` (or None) as this layer's `lora_layer` attribute."""
    self.lora_layer = lora_layer
def _fuse_lora(self):
    """Fold the attached LoRA layer into this layer's own weight, in place.

    Does nothing when `self.lora_layer` is None. Otherwise the up and down
    weights are flattened from their second dimension on, multiplied as
    matrices, reshaped to the shape of `self.weight` and added to it. The
    LoRA layer is then dropped and the matrices that were used are kept on
    the CPU as `self.w_up` / `self.w_down`.
    """
    lora_layer = self.lora_layer
    if lora_layer is None:
        return

    target_dtype = self.weight.data.dtype
    target_device = self.weight.data.device

    logger.info(f"Fusing LoRA weights for {self.__class__}")

    # Work in float32; the result is cast back to the original dtype below.
    base_weight = self.weight.data.float()
    up_weight = lora_layer.up.weight.data.float()
    down_weight = lora_layer.down.weight.data.float()

    alpha = lora_layer.network_alpha
    if alpha is not None:
        up_weight = up_weight * alpha / lora_layer.rank

    delta = torch.mm(up_weight.flatten(start_dim=1), down_weight.flatten(start_dim=1))
    delta = delta.reshape((base_weight.shape))
    merged = base_weight + delta
    self.weight.data = merged.to(device=target_device, dtype=target_dtype)

    # The LoRA layer is no longer needed once its delta lives in the weight.
    self.lora_layer = None

    # Keep the matrices on the CPU so they do not occupy accelerator memory.
    self.w_up = up_weight.cpu()
    self.w_down = down_weight.cpu()
def _unfuse_lora(self):
    """Subtract a previously fused LoRA delta from this layer's weight.

    Does nothing unless both `self.w_up` and `self.w_down` hold tensors. The
    attributes are reset to None after unfusing, so the guard tests for None
    as well as for a missing attribute; a plain `hasattr` check lets a second
    call through and fails on `None.to(...)`.
    """
    w_up = getattr(self, "w_up", None)
    w_down = getattr(self, "w_down", None)
    if w_up is None or w_down is None:
        return

    logger.info(f"Unfusing LoRA weights for {self.__class__}")

    fused_weight = self.weight.data
    dtype, device = fused_weight.dtype, fused_weight.device

    # Mirror `_fuse_lora`: do the arithmetic in float32 and cast back, so
    # the subtraction is not carried out in a reduced-precision dtype.
    w_up = w_up.to(device=device).float()
    w_down = w_down.to(device=device).float()

    fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
    fusion = fusion.reshape(fused_weight.shape)
    unfused_weight = fused_weight.float() - fusion
    self.weight.data = unfused_weight.to(device=device, dtype=dtype)

    # Release the stored matrices; this also marks the layer as not fused.
    self.w_up = None
    self.w_down = None
def forward(self, x): def forward(self, x):
if self.lora_layer is None: if self.lora_layer is None:
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
...@@ -109,9 +160,49 @@ class LoRACompatibleLinear(nn.Linear): ...@@ -109,9 +160,49 @@ class LoRACompatibleLinear(nn.Linear):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lora_layer = lora_layer self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
    """Store `lora_layer` (or None) as this layer's `lora_layer` attribute."""
    self.lora_layer = lora_layer
def _fuse_lora(self):
    """Fold the attached LoRA layer into this layer's own weight, in place.

    Does nothing when `self.lora_layer` is None. Otherwise the product of the
    LoRA up and down weights (scaled by `network_alpha / rank` when
    `network_alpha` is set) is added to `self.weight`, the LoRA layer is
    dropped, and the matrices that were used are kept on the CPU as
    `self.w_up` / `self.w_down`.
    """
    lora_layer = self.lora_layer
    if lora_layer is None:
        return

    target_dtype = self.weight.data.dtype
    target_device = self.weight.data.device

    logger.info(f"Fusing LoRA weights for {self.__class__}")

    # Work in float32; the result is cast back to the original dtype below.
    base_weight = self.weight.data.float()
    up_weight = lora_layer.up.weight.data.float()
    down_weight = lora_layer.down.weight.data.float()

    alpha = lora_layer.network_alpha
    if alpha is not None:
        up_weight = up_weight * alpha / lora_layer.rank

    delta = torch.bmm(up_weight[None, :], down_weight[None, :])[0]
    merged = base_weight + delta
    self.weight.data = merged.to(device=target_device, dtype=target_dtype)

    # The LoRA layer is no longer needed once its delta lives in the weight.
    self.lora_layer = None

    # Keep the matrices on the CPU so they do not occupy accelerator memory.
    self.w_up = up_weight.cpu()
    self.w_down = down_weight.cpu()
def _unfuse_lora(self):
    """Subtract a previously fused LoRA delta from this layer's weight.

    Does nothing unless both `self.w_up` and `self.w_down` hold tensors. The
    attributes are reset to None after unfusing, so the guard tests for None
    as well as for a missing attribute; a plain `hasattr` check lets a second
    call through and fails on `None.to(...)`.
    """
    w_up = getattr(self, "w_up", None)
    w_down = getattr(self, "w_down", None)
    if w_up is None or w_down is None:
        return

    logger.info(f"Unfusing LoRA weights for {self.__class__}")

    fused_weight = self.weight.data
    dtype, device = fused_weight.dtype, fused_weight.device

    # Mirror `_fuse_lora`: do the arithmetic in float32 and cast back, so
    # the subtraction is not carried out in a reduced-precision dtype.
    w_up = w_up.to(device=device).float()
    w_down = w_down.to(device=device).float()

    unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
    self.weight.data = unfused_weight.to(device=device, dtype=dtype)

    # Release the stored matrices; this also marks the layer as not fused.
    self.w_up = None
    self.w_down = None
def forward(self, hidden_states, lora_scale: int = 1): def forward(self, hidden_states, lora_scale: int = 1):
if self.lora_layer is None: if self.lora_layer is None:
return super().forward(hidden_states) return super().forward(hidden_states)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import os import os
import tempfile import tempfile
import time
import unittest import unittest
import numpy as np import numpy as np
...@@ -692,6 +693,124 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): ...@@ -692,6 +693,124 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe.unload_lora_weights() sd_pipe.unload_lora_weights()
def test_lora_fusion(self):
    """Loading and fusing LoRA weights must change the pipeline output."""
    components, lora = self.get_dummy_components()
    pipe = StableDiffusionXLPipeline(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def corner_slice():
        # Re-seed on every call so that runs are comparable.
        images = pipe(**inputs, generator=torch.manual_seed(0)).images
        return images[0, -3:, -3:, -1]

    orig_image_slice = corner_slice()

    # Emulate training.
    for layers_key in ("unet_lora_layers", "text_encoder_one_lora_layers", "text_encoder_two_lora_layers"):
        set_lora_weights(lora[layers_key].parameters(), randn_weight=True)

    with tempfile.TemporaryDirectory() as tmpdirname:
        weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
        StableDiffusionXLPipeline.save_lora_weights(
            save_directory=tmpdirname,
            unet_lora_layers=lora["unet_lora_layers"],
            text_encoder_lora_layers=lora["text_encoder_one_lora_layers"],
            text_encoder_2_lora_layers=lora["text_encoder_two_lora_layers"],
            safe_serialization=True,
        )
        self.assertTrue(os.path.isfile(weights_path))
        pipe.load_lora_weights(weights_path)

    pipe.fuse_lora()
    lora_image_slice = corner_slice()

    self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3))
def test_unfuse_lora(self):
    """`unfuse_lora()` must bring the output back to the LoRA-free result."""
    components, lora = self.get_dummy_components()
    pipe = StableDiffusionXLPipeline(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def corner_slice():
        # Re-seed on every call so that runs are comparable.
        images = pipe(**inputs, generator=torch.manual_seed(0)).images
        return images[0, -3:, -3:, -1]

    orig_image_slice = corner_slice()

    # Emulate training.
    for layers_key in ("unet_lora_layers", "text_encoder_one_lora_layers", "text_encoder_two_lora_layers"):
        set_lora_weights(lora[layers_key].parameters(), randn_weight=True)

    with tempfile.TemporaryDirectory() as tmpdirname:
        weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
        StableDiffusionXLPipeline.save_lora_weights(
            save_directory=tmpdirname,
            unet_lora_layers=lora["unet_lora_layers"],
            text_encoder_lora_layers=lora["text_encoder_one_lora_layers"],
            text_encoder_2_lora_layers=lora["text_encoder_two_lora_layers"],
            safe_serialization=True,
        )
        self.assertTrue(os.path.isfile(weights_path))
        pipe.load_lora_weights(weights_path)

    pipe.fuse_lora()
    lora_image_slice = corner_slice()

    # Reverse LoRA fusion.
    pipe.unfuse_lora()
    orig_image_slice_two = corner_slice()

    assert not np.allclose(
        orig_image_slice, lora_image_slice
    ), "Fusion of LoRAs should lead to a different image slice."
    assert not np.allclose(
        orig_image_slice_two, lora_image_slice
    ), "Fusion of LoRAs should lead to a different image slice."
    assert np.allclose(
        orig_image_slice, orig_image_slice_two, atol=1e-3
    ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
def test_lora_fusion_is_not_affected_by_unloading(self):
    """After fusion, `unload_lora_weights()` must leave the output unchanged."""
    components, lora = self.get_dummy_components()
    pipe = StableDiffusionXLPipeline(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def run_pipeline():
        # Re-seed on every call so that runs are comparable.
        return pipe(**inputs, generator=torch.manual_seed(0)).images

    # One run without any LoRA parameters; its result is not used.
    _ = run_pipeline()

    # Emulate training.
    for layers_key in ("unet_lora_layers", "text_encoder_one_lora_layers", "text_encoder_two_lora_layers"):
        set_lora_weights(lora[layers_key].parameters(), randn_weight=True)

    with tempfile.TemporaryDirectory() as tmpdirname:
        weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
        StableDiffusionXLPipeline.save_lora_weights(
            save_directory=tmpdirname,
            unet_lora_layers=lora["unet_lora_layers"],
            text_encoder_lora_layers=lora["text_encoder_one_lora_layers"],
            text_encoder_2_lora_layers=lora["text_encoder_two_lora_layers"],
            safe_serialization=True,
        )
        self.assertTrue(os.path.isfile(weights_path))
        pipe.load_lora_weights(weights_path)

    pipe.fuse_lora()
    lora_image_slice = run_pipeline()[0, -3:, -3:, -1]

    # Unload LoRA parameters.
    pipe.unload_lora_weights()
    images_with_unloaded_lora_slice = run_pipeline()[0, -3:, -3:, -1]

    assert np.allclose(
        lora_image_slice, images_with_unloaded_lora_slice
    ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused."
@slow @slow
@require_torch_gpu @require_torch_gpu
...@@ -877,10 +996,10 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -877,10 +996,10 @@ class LoraIntegrationTests(unittest.TestCase):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora" lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
lora_filename = "daiton-xl-lora-test.safetensors" lora_filename = "daiton-xl-lora-test.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe( images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
...@@ -895,10 +1014,10 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -895,10 +1014,10 @@ class LoraIntegrationTests(unittest.TestCase):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora" lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
lora_filename = "saijo.safetensors" lora_filename = "saijo.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe( images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
...@@ -913,10 +1032,10 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -913,10 +1032,10 @@ class LoraIntegrationTests(unittest.TestCase):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora" lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors" lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe( images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
...@@ -931,19 +1050,127 @@ class LoraIntegrationTests(unittest.TestCase): ...@@ -931,19 +1050,127 @@ class LoraIntegrationTests(unittest.TestCase):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_sdxl_1_0_lora_fusion(self):
    """Compare a fused-LoRA SDXL 1.0 run against the stored reference slice."""
    generator = torch.Generator().manual_seed(0)

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights(
        "hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
    )
    pipe.fuse_lora()
    pipe.enable_model_cpu_offload()

    output = pipe(
        "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
    )
    image_slice = output.images[0, -3:, -3:, -1].flatten()

    # This way we also test equivalence between LoRA fusion and the non-fusion behaviour.
    expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

    self.assertTrue(np.allclose(image_slice, expected, atol=1e-4))
def test_sdxl_1_0_lora_unfusion(self):
    """Outputs with fused LoRA and after `unfuse_lora()` must differ."""
    prompt = "masterpiece, best quality, mountain"

    generator = torch.Generator().manual_seed(0)

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights(
        "hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
    )
    pipe.fuse_lora()
    pipe.enable_model_cpu_offload()

    fused_output = pipe(prompt, output_type="np", generator=generator, num_inference_steps=2)
    images_with_fusion = fused_output.images[0, -3:, -3:, -1].flatten()

    pipe.unfuse_lora()

    # Fresh generator with the same seed for the second run.
    generator = torch.Generator().manual_seed(0)
    unfused_output = pipe(prompt, output_type="np", generator=generator, num_inference_steps=2)
    images_without_fusion = unfused_output.images[0, -3:, -3:, -1].flatten()

    self.assertFalse(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
def test_sdxl_1_0_lora_unfusion_effectivity(self):
    """Fusing and then unfusing must reproduce the output obtained before loading LoRA."""
    prompt = "masterpiece, best quality, mountain"

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.enable_model_cpu_offload()

    # Run before any LoRA weights are loaded.
    generator = torch.Generator().manual_seed(0)
    base_output = pipe(prompt, output_type="np", generator=generator, num_inference_steps=2)
    original_image_slice = base_output.images[0, -3:, -3:, -1].flatten()

    pipe.load_lora_weights(
        "hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
    )
    pipe.fuse_lora()

    # Run once with the fused weights; the images are not inspected.
    generator = torch.Generator().manual_seed(0)
    _ = pipe(prompt, output_type="np", generator=generator, num_inference_steps=2).images

    pipe.unfuse_lora()

    generator = torch.Generator().manual_seed(0)
    unfused_output = pipe(prompt, output_type="np", generator=generator, num_inference_steps=2)
    images_without_fusion_slice = unfused_output.images[0, -3:, -3:, -1].flatten()

    self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3))
def test_sdxl_1_0_lora_fusion_efficiency(self):
    """Three inference runs with fused LoRA must take less time than without fusion.

    Elapsed time is measured with `time.perf_counter()`, a monotonic clock
    meant for intervals; `time.time()` can jump when the system clock is
    adjusted. Both measurements are taken by the same helper so that seeding
    and timing are identical for the two pipelines.

    NOTE(review): the first measurement may still include one-time start-up
    cost that the second does not pay - confirm whether a warm-up run is
    wanted.
    """
    lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
    lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"

    def time_three_runs(pipeline):
        # Seed outside the timed region so only inference is measured.
        generator = torch.Generator().manual_seed(0)
        start_time = time.perf_counter()
        for _ in range(3):
            pipeline(
                "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
            ).images
        return time.perf_counter() - start_time

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
    pipe.enable_model_cpu_offload()
    elapsed_time_non_fusion = time_three_runs(pipe)

    del pipe

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
    pipe.fuse_lora()
    pipe.enable_model_cpu_offload()
    elapsed_time_fusion = time_three_runs(pipe)

    # assertLess reports both timings when the comparison fails.
    self.assertLess(elapsed_time_fusion, elapsed_time_non_fusion)
def test_sdxl_1_0_last_ben(self): def test_sdxl_1_0_last_ben(self):
generator = torch.Generator().manual_seed(0) generator = torch.Generator().manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment