Unverified Commit c81a88b2 authored by Sayak Paul, committed by GitHub

[Core] LoRA improvements pt. 3 (#4842)



* throw a warning when an attempt is made to fuse more than one LoRA.

* introduce support for a LoRA scale during fusion (a brief sketch of the idea follows the commit metadata below).

* change test name

* changes

* change to _lora_scale

* lora_scale to call whenever applicable.

* debugging

* lora_scale additional.

* cross_attention_kwargs

* lora_scale -> scale.

* lora_scale fix

* lora_scale in patched projection.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* styling.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* remove unneeded prints.

* remove unneeded prints.

* assign cross_attention_kwargs.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up.

* refactor scale retrieval logic a bit.

* fix NoneType

* fix: tests

* add more tests

* more fixes.

* figure out a way to pass lora_scale.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* unify the retrieval logic of lora_scale.

* move adjust_lora_scale_text_encoder to lora.py.

* introduce dynamic LoRA scale adjustment support to SD

* fix up copies

* Empty-Commit

* add: test to check fusion equivalence on different scales.

* handle lora fusion warning.

* make lora smaller

* make lora smaller

* make lora smaller

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c1677ee
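Before the diff, a note on the core idea. Two of the commits above (the fusion-scale support and the fusion-equivalence test) rely on the fact that scaling a LoRA branch at call time and folding the scaled branch into the base weight are mathematically equivalent. The following is a minimal, self-contained sketch of that equivalence; the class and method names (LoRALinearSketch, fuse) are illustrative and are not the diffusers implementation.

import torch


class LoRALinearSketch(torch.nn.Module):
    """Illustrative linear layer with a LoRA branch; not the diffusers class."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = torch.nn.Linear(in_features, out_features, bias=False)
        self.lora_down = torch.nn.Linear(in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, out_features, bias=False)
        torch.nn.init.normal_(self.lora_down.weight, std=0.02)
        torch.nn.init.normal_(self.lora_up.weight, std=0.02)

    def forward(self, x, scale: float = 1.0):
        # Dynamic scaling: weight the LoRA branch by `scale` on every call.
        return self.base(x) + scale * self.lora_up(self.lora_down(x))

    def fuse(self, lora_scale: float = 1.0):
        # Fusion: fold `lora_scale * up @ down` into the base weight once.
        self.base.weight.data += lora_scale * self.lora_up.weight @ self.lora_down.weight


layer = LoRALinearSketch(8, 8)
x = torch.randn(2, 8)
dynamic = layer(x, scale=0.5)   # scale applied at call time
layer.fuse(lora_scale=0.5)
fused = layer(x, scale=0.0)     # LoRA branch off; 0.5 is already baked into the base weight
assert torch.allclose(dynamic, fused, atol=1e-6)

This is exactly the property that test_with_different_scales_fusion_equivalence checks end to end on the SDXL pipeline further down.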
@@ -31,6 +31,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
@@ -312,6 +313,10 @@ class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, L
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
@@ -21,6 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
@@ -246,6 +247,9 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
@@ -22,6 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
@@ -308,6 +309,9 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
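The three pipeline hunks above (StableDiffusionXLAdapterPipeline, TextToVideoSDPipeline, VideoToVideoSDPipeline) apply the same pattern: when a lora_scale reaches encode_prompt, it is stored on the pipeline as _lora_scale and also pushed into the text encoder's LoRA-patched projection layers via adjust_lora_scale_text_encoder, so the scale takes effect without reloading or re-fusing anything. Conceptually the helper behaves like the sketch below; the lora_scale attribute on the patched modules is an assumption for illustration, not a guarantee about the exact diffusers internals.

def adjust_lora_scale_text_encoder_sketch(text_encoder, lora_scale: float = 1.0):
    # Walk the text encoder and update the scale on every module that carries a
    # LoRA branch; assume such modules expose a writable `lora_scale` attribute.
    for module in text_encoder.modules():
        if hasattr(module, "lora_scale"):
            module.lora_scale = lora_scale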
@@ -1143,6 +1143,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1165,7 +1166,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
@@ -1229,7 +1230,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
)
# 6. post-process
@@ -1410,7 +1415,7 @@ class DownBlockFlat(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1431,13 +1436,13 @@ class DownBlockFlat(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
@@ -1547,6 +1552,8 @@ class CrossAttnDownBlockFlat(nn.Module):
):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
@@ -1577,7 +1584,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1595,7 +1602,7 @@ class CrossAttnDownBlockFlat(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
@@ -1651,7 +1658,7 @@ class UpBlockFlat(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -1675,11 +1682,11 @@ class UpBlockFlat(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
@@ -1782,6 +1789,8 @@ class CrossAttnUpBlockFlat(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -1815,7 +1824,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1827,7 +1836,7 @@ class CrossAttnUpBlockFlat(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
@@ -1932,7 +1941,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
@@ -1969,7 +1979,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -2070,6 +2080,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2082,7 +2093,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
@@ -2093,6 +2104,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
)
# resnet
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
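The UNet-side hunks above are mechanical but consistent: every block reads lora_scale from cross_attention_kwargs["scale"] (defaulting to 1.0) and forwards it as scale= to its resnets, downsamplers and upsamplers, which in turn hand it to their LoRA-aware linear/conv layers. A hedged sketch of the receiving end is shown below; the class name is illustrative (in diffusers this role is played by layers such as LoRACompatibleLinear), and attaching an actual lora_layer is left out.

import torch


class LoRACompatibleLinearSketch(torch.nn.Linear):
    """Linear layer that optionally adds a scaled LoRA correction (illustrative)."""

    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer  # e.g. a small down/up projection pair, or None

    def forward(self, hidden_states, scale: float = 1.0):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            # The `scale` threaded down from cross_attention_kwargs ends up here.
            out = out + scale * self.lora_layer(hidden_states)
        return out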
@@ -92,11 +92,11 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
return text_encoder_lora_layers
def set_lora_weights(lora_attn_parameters, randn_weight=False):
def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0):
with torch.no_grad():
for parameter in lora_attn_parameters:
if randn_weight:
parameter[:] = torch.randn_like(parameter)
parameter[:] = torch.randn_like(parameter) * var
else:
torch.zero_(parameter)
@@ -719,7 +719,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
# default & unloaded LoRA weights should have identical state_dicts
assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
# default & loaded LoRA weights should NOT have identical state_dicts
assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 #
assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2
# default & unloaded LoRA weights should have identical state_dicts
assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
@@ -863,6 +863,161 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
lora_image_slice, images_with_unloaded_lora_slice
), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused."
def test_fuse_lora_with_different_scales(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
_ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
sd_pipe.fuse_lora(lora_scale=1.0)
lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
# Reverse LoRA fusion.
sd_pipe.unfuse_lora()
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
sd_pipe.fuse_lora(lora_scale=0.5)
lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
assert not np.allclose(
lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
), "Different LoRA scales should influence the outputs accordingly."
def test_with_different_scales(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
original_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
lora_images_scale_0_5 = sd_pipe(
**pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images
lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
lora_images_scale_0_0 = sd_pipe(
**pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1]
assert not np.allclose(
lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
), "Different LoRA scales should influence the outputs accordingly."
assert np.allclose(
original_image_slice, lora_image_slice_scale_0_0, atol=1e-03
), "LoRA scale of 0.0 shouldn't be different from the results without LoRA."
def test_with_different_scales_fusion_equivalence(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
# sd_pipe.unet.set_default_attn_processor()
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
images = sd_pipe(
**pipeline_inputs,
generator=torch.manual_seed(0),
).images
images_slice = images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_images_scale_0_5 = sd_pipe(
**pipeline_inputs,
generator=torch.manual_seed(0),
cross_attention_kwargs={"scale": 0.5},
).images
lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
sd_pipe.fuse_lora(lora_scale=0.5)
lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1]
assert np.allclose(
lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03
), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale."
sd_pipe.unfuse_lora()
images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
images_slice_unfused = images_unfused[0, -3:, -3:, -1]
assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA"
assert not np.allclose(
images_slice, lora_image_slice_scale_0_5, atol=1e-03
), "0.5 scale and no scale shouldn't match"
@slow
@require_torch_gpu
......
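Taken together, the tests above cover the two user-facing ways of applying a LoRA scale that this commit wires up. A hedged usage sketch (the checkpoint id and LoRA path are placeholders):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")  # placeholder path

# Option 1: dynamic scaling at call time via cross_attention_kwargs.
image = pipe("an astronaut riding a horse", cross_attention_kwargs={"scale": 0.5}).images[0]

# Option 2: fuse the LoRA into the base weights at a given scale, then call normally.
pipe.fuse_lora(lora_scale=0.5)
image = pipe("an astronaut riding a horse").images[0]
pipe.unfuse_lora()  # reverse the fusion when the LoRA is no longer wanted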