Unverified Commit 74fd735e authored by Patrick von Platen, committed by GitHub

Add draft for lora text encoder scale (#3626)



* Add draft for lora text encoder scale

* Improve naming

* fix: training dreambooth lora script.

* Apply suggestions from code review

* Update examples/dreambooth/train_dreambooth_lora.py

* Apply suggestions from code review

* Apply suggestions from code review

* add lora mixin when fit

* add lora mixin when fit

* add lora mixin when fit

* fix more

* fix more

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2de9e2df
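For orientation, the net effect of this change is that the `scale` entry of `cross_attention_kwargs` is now forwarded to the text encoder's LoRA layers as well, not only to the UNet attention processors. A minimal usage sketch, assuming a standard SD 1.5 checkpoint and a placeholder LoRA path (the call pattern mirrors the new `test_text_encoder_lora_scale` test at the bottom of this diff):

```py
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint id and LoRA weight directory.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora")  # loads UNet and text encoder LoRA layers

# "scale" still reaches the UNet attention processors, and is now also passed
# to `_encode_prompt` as `lora_scale` so the text encoder LoRA layers use it.
image = pipe("a photo of an astronaut", cross_attention_kwargs={"scale": 0.5}).images[0]
```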
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin
@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".
@@ -237,6 +237,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -261,7 +262,14 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -719,6 +727,9 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -727,6 +738,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
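The `self._lora_scale = lora_scale` assignment added to each `_encode_prompt` works because the text encoder's LoRA layers are monkey patched and read the scale back from the pipeline rather than receiving it as a call argument. A rough, self-contained sketch of that mechanism (an illustration under that assumption, not the actual diffusers implementation):

```py
import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Base projection plus a low-rank update scaled by the owner's `_lora_scale`."""

    def __init__(self, base: nn.Linear, owner, rank: int = 4):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        self.owner = owner  # the pipeline object on which `_encode_prompt` sets `_lora_scale`

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fall back to 1.0 when no scale has been set yet.
        scale = getattr(self.owner, "_lora_scale", 1.0)
        return self.base(x) + scale * self.up(self.down(x))


class FakePipeline:
    """Stands in for a pipeline that mixes in LoraLoaderMixin."""


pipe = FakePipeline()
layer = ToyLoRALinear(nn.Linear(8, 8), owner=pipe)

pipe._lora_scale = 0.5  # what `_encode_prompt(..., lora_scale=0.5)` effectively does
out = layer(torch.randn(1, 8))
```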
@@ -20,7 +20,7 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -51,7 +51,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
Generation".
@@ -199,6 +199,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -223,7 +224,14 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -586,6 +594,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -594,6 +605,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
@@ -30,7 +30,7 @@ from transformers import (
)
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
@@ -447,6 +447,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -471,7 +472,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -21,7 +21,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -218,6 +218,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -242,7 +243,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -14,7 +14,7 @@
import inspect
import warnings
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
@@ -22,7 +22,7 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
@@ -60,7 +60,7 @@ def preprocess(image):
return image
class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
@@ -224,6 +224,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -248,7 +249,14 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -514,6 +522,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -568,6 +577,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Examples:
```py
@@ -632,6 +645,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -640,6 +656,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Preprocess image
@@ -703,6 +720,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=noise_level,
return_dict=False,
)[0]
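Besides the text encoder change, this file gives `StableDiffusionUpscalePipeline.__call__` a `cross_attention_kwargs` argument that is handed to the UNet and queried for the `"scale"` key before prompt encoding. A hedged usage sketch (checkpoint id and LoRA path are placeholders, and the low-resolution input is a dummy image):

```py
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler")
pipe.load_lora_weights("path/to/lora")  # placeholder path

low_res = Image.new("RGB", (128, 128))  # stand-in for a real low-resolution image

# The same "scale" value now reaches both the UNet attention processors and,
# via `lora_scale` in `_encode_prompt`, the text encoder LoRA layers.
upscaled = pipe(
    prompt="a white cat",
    image=low_res,
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```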
@@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -50,7 +50,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
"""
Pipeline for text-to-image generation using stable unCLIP.
@@ -338,6 +338,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -362,7 +363,14 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -856,6 +864,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
do_classifier_free_guidance = guidance_scale > 1.0
# 8. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
@@ -864,6 +875,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 9. Prepare image embeddings
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from diffusers.utils.import_utils import is_accelerate_available
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -63,7 +63,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
"""
Pipeline for text-guided image to image generation using stable unCLIP.
@@ -238,6 +238,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -262,7 +263,14 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -752,6 +760,9 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
@@ -760,6 +771,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Encoder input image
@@ -19,7 +19,7 @@ import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -73,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -
return images
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -224,6 +224,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -248,7 +249,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -591,6 +599,9 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -599,6 +610,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
@@ -173,6 +173,17 @@ class LoraLoaderMixinTests(unittest.TestCase):
return noise, input_ids, pipeline_inputs
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs
def create_lora_weight_file(self, tmpdirname):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(
@@ -188,7 +199,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
@@ -214,7 +225,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
@@ -242,7 +253,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
@@ -260,16 +271,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs
def test_text_encoder_lora_monkey_patch(self):
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)
@@ -358,6 +359,34 @@ class LoraLoaderMixinTests(unittest.TestCase):
outputs_without_lora, outputs_without_lora_removed
), "remove lora monkey patch should restore the original outputs"
def test_text_encoder_lora_scale(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs()
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
lora_images_with_scale = sd_pipe(**pipeline_inputs, cross_attention_kwargs={"scale": 0.5}).images
lora_image_with_scale_slice = lora_images_with_scale[0, -3:, -3:, -1]
# Outputs shouldn't match.
self.assertFalse(
torch.allclose(torch.from_numpy(lora_image_slice), torch.from_numpy(lora_image_with_scale_slice))
)
def test_lora_unet_attn_processors(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.create_lora_weight_file(tmpdirname)
@@ -416,7 +445,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
_, _, pipeline_inputs = self.get_dummy_inputs()
# enable XFormers
sd_pipe.enable_xformers_memory_efficient_attention()