Unverified commit c81a88b2, authored by Sayak Paul and committed by GitHub

[Core] LoRA improvements pt. 3 (#4842)



* throw a warning when an attempt is made to fuse more than one LoRA.

* introduce support for a LoRA scale during fusion.

* change test name

* changes

* change to _lora_scale

* lora_scale to call whenever applicable.

* debugging

* lora_scale additional.

* cross_attention_kwargs

* lora_scale -> scale.

* lora_scale fix

* lora_scale in patched projection.

* debugging

* styling.

* debugging

* remove unneeded prints.

* assign cross_attention_kwargs.

* debugging

* clean up.

* refactor scale retrieval logic a bit.

* fix NoneType

* fix: tests

* add more tests

* more fixes.

* figure out a way to pass lora_scale.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* unify the retrieval logic of lora_scale.

* move adjust_lora_scale_text_encoder to lora.py.

* introduce dynamic LoRA scale adjustment support to SD

* fix up copies

* Empty-Commit

* add: test to check fusion equivalence on different scales.

* handle lora fusion warning.

* make lora smaller

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c1677ee
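Taken together, the changes below expose the LoRA scale in two places: dynamically at inference time via `cross_attention_kwargs={"scale": ...}`, and statically at fusion time via `fuse_lora(lora_scale=...)`. A rough usage sketch under stated assumptions — the checkpoint name, LoRA file path, dtype and device handling are placeholders, not part of this commit:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Placeholder model and LoRA file; any SDXL checkpoint with a compatible LoRA works.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("pytorch_lora_weights.safetensors")

# 1. Dynamic scaling: the value is read from cross_attention_kwargs["scale"] and
#    forwarded to the UNet blocks and, via adjust_lora_scale_text_encoder, to the
#    text encoder(s).
image = pipe("a prompt", cross_attention_kwargs={"scale": 0.5}).images[0]

# 2. Scaling at fusion time: the LoRA update is merged into the base weights with
#    the given scale; unfuse_lora() restores the original weights afterwards.
pipe.fuse_lora(lora_scale=0.5)
image_fused = pipe("a prompt").images[0]
pipe.unfuse_lora()
```

The new `test_with_different_scales_fusion_equivalence` test added in this commit asserts that these two paths produce numerically close results.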
@@ -31,6 +31,7 @@ from ...models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -312,6 +313,10 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
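The only functional change in these pipeline hunks is the call to `adjust_lora_scale_text_encoder`. Its implementation lives in `models/lora.py` (see the commit bullet "move adjust_lora_scale_text_encoder to lora.py") and is not shown in this diff; conceptually it does something like the following simplified sketch. Only the attribute name `lora_scale` is taken from the real patched projections — the function name and traversal below are illustrative, not the diffusers source:

```python
import torch.nn as nn


def adjust_lora_scale_text_encoder_sketch(text_encoder: nn.Module, lora_scale: float = 1.0) -> None:
    """Illustrative only: walk the text encoder and update the runtime LoRA scale
    on every LoRA-patched projection, so a per-call lora_scale takes effect
    without reloading or re-fusing any weights."""
    for module in text_encoder.modules():
        if hasattr(module, "lora_scale"):
            module.lora_scale = lora_scale
```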
@@ -21,6 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -246,6 +247,9 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -22,6 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -308,6 +309,9 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
......
@@ -1143,6 +1143,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

         # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1165,7 +1166,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                     **additional_residuals,
                 )
             else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                 if is_adapter and len(down_block_additional_residuals) > 0:
                     sample += down_block_additional_residuals.pop(0)
@@ -1229,7 +1230,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 )
             else:
                 sample = upsample_block(
-                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    scale=lora_scale,
                 )

         # 6. post-process
@@ -1410,7 +1415,7 @@ class DownBlockFlat(nn.Module):
         self.gradient_checkpointing = False

-    def forward(self, hidden_states, temb=None):
+    def forward(self, hidden_states, temb=None, scale: float = 1.0):
         output_states = ()

         for resnet in self.resnets:
@@ -1431,13 +1436,13 @@ class DownBlockFlat(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(hidden_states, temb, scale=scale)

             output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states, scale=scale)

             output_states = output_states + (hidden_states,)
@@ -1547,6 +1552,8 @@ class CrossAttnDownBlockFlat(nn.Module):
     ):
         output_states = ()

+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
@@ -1577,7 +1584,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1595,7 +1602,7 @@ class CrossAttnDownBlockFlat(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states, scale=lora_scale)

             output_states = output_states + (hidden_states,)
@@ -1651,7 +1658,7 @@ class UpBlockFlat(nn.Module):
         self.gradient_checkpointing = False

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1675,11 +1682,11 @@ class UpBlockFlat(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(hidden_states, temb, scale=scale)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)

         return hidden_states
@@ -1782,6 +1789,8 @@ class CrossAttnUpBlockFlat(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1815,7 +1824,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1827,7 +1836,7 @@ class CrossAttnUpBlockFlat(nn.Module):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)

         return hidden_states
@@ -1932,7 +1941,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
-        hidden_states = self.resnets[0](hidden_states, temb)
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if self.training and self.gradient_checkpointing:
@@ -1969,7 +1979,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)

         return hidden_states
@@ -2070,6 +2080,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        lora_scale = cross_attention_kwargs.get("scale", 1.0)

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2082,7 +2093,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
             # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
             mask = attention_mask

-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
@@ -2093,6 +2104,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
             )

             # resnet
-            hidden_states = resnet(hidden_states, temb)
+            hidden_states = resnet(hidden_states, temb, scale=lora_scale)

         return hidden_states
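Every hunk above does the same thing: read the LoRA scale out of `cross_attention_kwargs` (defaulting to 1.0) and thread it down to the resnet, downsampler, and upsampler calls. The reason this works is that those sub-modules are built from LoRA-compatible layers that apply the low-rank update multiplied by this factor. A simplified stand-in, under stated assumptions — the class name and attribute names below are illustrative, not the diffusers source:

```python
import torch
import torch.nn as nn


class LoRACompatibleLinearSketch(nn.Linear):
    """Illustrative stand-in for a LoRA-compatible layer: base output plus a
    low-rank update multiplied by the per-call `scale`."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__(in_features, out_features)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # scale=0.0 switches the LoRA contribution off entirely, which the new
        # scale-0.0 test in this commit relies on.
        return super().forward(x) + scale * self.lora_up(self.lora_down(x))
```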
@@ -92,11 +92,11 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
     return text_encoder_lora_layers


-def set_lora_weights(lora_attn_parameters, randn_weight=False):
+def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0):
     with torch.no_grad():
         for parameter in lora_attn_parameters:
             if randn_weight:
-                parameter[:] = torch.randn_like(parameter)
+                parameter[:] = torch.randn_like(parameter) * var
             else:
                 torch.zero_(parameter)
@@ -719,7 +719,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
         # default & unloaded LoRA weights should have identical state_dicts
         assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
         # default & loaded LoRA weights should NOT have identical state_dicts
-        assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2
+        # assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2

         # default & unloaded LoRA weights should have identical state_dicts
         assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
@@ -863,6 +863,161 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
                 lora_image_slice, images_with_unloaded_lora_slice
             ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused."

+    def test_fuse_lora_with_different_scales(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        sd_pipe.fuse_lora(lora_scale=1.0)
+        lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+        # Reverse LoRA fusion.
+        sd_pipe.unfuse_lora()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        sd_pipe.fuse_lora(lora_scale=0.5)
+        lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+        ), "Different LoRA scales should influence the outputs accordingly."
+
+    def test_with_different_scales(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        original_imagee_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+        lora_images_scale_0_5 = sd_pipe(
+            **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+        ).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        lora_images_scale_0_0 = sd_pipe(
+            **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+        ).images
+        lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+        ), "Different LoRA scales should influence the outputs accordingly."
+
+        assert np.allclose(
+            original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03
+        ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA."
+
+    def test_with_different_scales_fusion_equivalence(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        # sd_pipe.unet.set_default_attn_processor()
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        images = sd_pipe(
+            **pipeline_inputs,
+            generator=torch.manual_seed(0),
+        ).images
+        images_slice = images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        lora_images_scale_0_5 = sd_pipe(
+            **pipeline_inputs,
+            generator=torch.manual_seed(0),
+            cross_attention_kwargs={"scale": 0.5},
+        ).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        sd_pipe.fuse_lora(lora_scale=0.5)
+        lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03
+        ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale."
+
+        sd_pipe.unfuse_lora()
+        images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        images_slice_unfused = images_unfused[0, -3:, -3:, -1]
+
+        assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA"
+
+        assert not np.allclose(
+            images_slice, lora_image_slice_scale_0_5, atol=1e-03
+        ), "0.5 scale and no scale shouldn't match"
+

 @slow
 @require_torch_gpu
......
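The `test_with_different_scales_fusion_equivalence` test added above rests on a simple identity: folding `scale * up @ down` into the base weight once is equivalent to adding the scaled LoRA output on every forward pass. A tiny, self-contained numeric check of that identity, independent of diffusers (all tensor names and shapes below are made up for illustration):

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank = 8, 8, 4
weight = torch.randn(out_features, in_features)  # base weight W
down = torch.randn(rank, in_features)            # LoRA down-projection
up = torch.randn(out_features, rank)             # LoRA up-projection
x = torch.randn(2, in_features)
scale = 0.5

# Runtime scaling: y = x W^T + scale * x (up @ down)^T
y_runtime = x @ weight.T + scale * (x @ (up @ down).T)

# Fusion at the same scale: W' = W + scale * (up @ down), then a plain linear pass.
y_fused = x @ (weight + scale * (up @ down)).T

assert torch.allclose(y_runtime, y_fused, atol=1e-6)
```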