Unverified commit 62863bb1, authored by YiYi Xu, committed by GitHub

Revert "[LoRA] introduce LoraBaseMixin to promote reusability." (#8976)

Revert "[LoRA] introduce LoraBaseMixin to promote reusability. (#8774)"

This reverts commit 527430d0.
parent 1fd647f2
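
For context: this revert renames the pipeline-level LoRA mixin back from StableDiffusionLoraLoaderMixin to LoraLoaderMixin and restores the pre-#8774 test layout. The user-facing entry points are unchanged either way; a minimal usage sketch of those entry points (the model ID and LoRA path below are placeholders, not taken from this commit):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # After this revert, Stable Diffusion pipelines inherit LoraLoaderMixin,
    # so LoRA checkpoints keep loading through the same methods as before:
    pipe.load_lora_weights("path/to/lora")  # placeholder path
    pipe.unload_lora_weights()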
@@ -11,7 +11,7 @@ from torch.nn.functional import grid_sample
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -281,9 +281,7 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
     return warped_latents


-class TextToVideoZeroPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
-):
+class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -833,7 +831,7 @@ class TextToVideoZeroPipeline(
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
@@ -966,7 +964,7 @@ class TextToVideoZeroPipeline(
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

         if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
...
@@ -14,7 +14,7 @@ from transformers import (
 )

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -422,7 +422,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
@@ -555,7 +555,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

         if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
...
@@ -20,7 +20,7 @@ import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer

-from ...loaders import StableDiffusionLoraLoaderMixin
+from ...loaders import LoraLoaderMixin
 from ...schedulers import DDPMWuerstchenScheduler
 from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
@@ -62,7 +62,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
     image_embeddings: Union[torch.Tensor, np.ndarray]


-class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
+class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
     """
     Pipeline for generating image prior for Wuerstchen.
@@ -70,8 +70,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     The pipeline also inherits the following loading methods:
-        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights

     Args:
         prior ([`Prior`]):
@@ -95,7 +95,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
     text_encoder_name = "text_encoder"
     model_cpu_offload_seq = "text_encoder->prior"
     _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
-    _lora_loadable_modules = ["prior", "text_encoder"]

     def __init__(
         self,
...
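The three pipeline hunks above all touch the same encode-prompt pattern. Condensed into standalone helpers, it looks roughly like this (a sketch assembled from the diff context and public diffusers utilities; the two helper names are hypothetical, not part of this commit):

    from diffusers.loaders import LoraLoaderMixin
    from diffusers.models.lora import adjust_lora_scale_text_encoder
    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

    def set_text_encoder_lora_scale(pipe, lora_scale):
        # Hypothetical helper: the "set" half of the pattern in the hunks above.
        # The scale is stashed on the pipeline so the monkey-patched LoRA
        # function of the text encoder can access it.
        if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
            pipe._lora_scale = lora_scale
            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder, lora_scale)

    def restore_text_encoder_lora_scale(pipe, lora_scale):
        # Hypothetical helper: the "restore" half, run after prompt encoding.
        if pipe.text_encoder is not None:
            if isinstance(pipe, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(pipe.text_encoder, lora_scale)

The revert only changes which mixin class the isinstance checks name; the scale/unscale logic itself is untouched.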
@@ -12,55 +12,376 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
+import tempfile
 import unittest

+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
 from diffusers import (
+    AutoencoderKL,
     FlowMatchEulerDiscreteScheduler,
+    SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
 from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device


 if is_peft_available():
-    pass
+    from peft import LoraConfig
+    from peft.utils import get_peft_model_state_dict

 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import check_if_lora_correctly_set  # noqa: E402


 @require_peft_backend
-class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class SD3LoRATests(unittest.TestCase):
     pipeline_class = StableDiffusion3Pipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
-    scheduler_kwargs = {}
-    transformer_kwargs = {
-        "sample_size": 32,
-        "patch_size": 1,
-        "in_channels": 4,
-        "num_layers": 1,
-        "attention_head_dim": 8,
-        "num_attention_heads": 4,
-        "caption_projection_dim": 32,
-        "joint_attention_dim": 32,
-        "pooled_projection_dim": 64,
-        "out_channels": 4,
-    }
-    vae_kwargs = {
-        "sample_size": 32,
-        "in_channels": 3,
-        "out_channels": 3,
-        "block_out_channels": (4,),
-        "layers_per_block": 1,
-        "latent_channels": 4,
-        "norm_num_groups": 1,
-        "use_quant_conv": False,
-        "use_post_quant_conv": False,
-        "shift_factor": 0.0609,
-        "scaling_factor": 1.5035,
-    }
-    has_three_text_encoders = True
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        transformer = SD3Transformer2DModel(
+            sample_size=32,
+            patch_size=1,
+            in_channels=4,
+            num_layers=1,
+            attention_head_dim=8,
+            num_attention_heads=4,
+            caption_projection_dim=32,
+            joint_attention_dim=32,
+            pooled_projection_dim=64,
+            out_channels=4,
+        )
+        clip_text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+
+        torch.manual_seed(0)
+        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+        torch.manual_seed(0)
+        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            block_out_channels=(4,),
+            layers_per_block=1,
+            latent_channels=4,
+            norm_num_groups=1,
+            use_quant_conv=False,
+            use_post_quant_conv=False,
+            shift_factor=0.0609,
+            scaling_factor=1.5035,
+        )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "text_encoder_3": text_encoder_3,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "tokenizer_3": tokenizer_3,
+            "transformer": transformer,
+            "vae": vae,
+        }
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+        }
+        return inputs
+
+    def get_lora_config_for_transformer(self):
+        lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        return lora_config
+
+    def get_lora_config_for_text_encoders(self):
+        text_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        return text_lora_config
+
+    def test_simple_inference_with_transformer_lora_save_load(self):
+        components = self.get_dummy_components()
+        transformer_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.transformer.add_adapter(transformer_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        images_lora = pipe(**inputs).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
+
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname,
+                transformer_lora_layers=transformer_state_dict,
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        inputs = self.get_dummy_inputs(torch_device)
+        images_lora_from_pretrained = pipe(**inputs).images
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+
+    def test_simple_inference_with_clip_encoders_lora_save_load(self):
+        components = self.get_dummy_components()
+        transformer_config = self.get_lora_config_for_transformer()
+        text_encoder_config = self.get_lora_config_for_text_encoders()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        pipe.transformer.add_adapter(transformer_config)
+        pipe.text_encoder.add_adapter(text_encoder_config)
+        pipe.text_encoder_2.add_adapter(text_encoder_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        images_lora = pipe(**inputs).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
+            text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname,
+                transformer_lora_layers=transformer_state_dict,
+                text_encoder_lora_layers=text_encoder_one_state_dict,
+                text_encoder_2_lora_layers=text_encoder_two_state_dict,
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        inputs = self.get_dummy_inputs(torch_device)
+        images_lora_from_pretrained = pipe(**inputs).images
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+
+    def test_simple_inference_with_transformer_lora_and_scale(self):
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora = pipe(**inputs).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+        )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+            "Lora + scale should change the output",
+        )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
+        self.assertTrue(
+            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+            "Lora + 0 scale should lead to same result as no LoRA",
+        )
+
+    def test_simple_inference_with_clip_encoders_lora_and_scale(self):
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        text_encoder_config = self.get_lora_config_for_text_encoders()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        pipe.text_encoder.add_adapter(text_encoder_config)
+        pipe.text_encoder_2.add_adapter(text_encoder_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
+        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora = pipe(**inputs).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+        )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
+        self.assertTrue(
+            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+            "Lora + scale should change the output",
+        )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
+        self.assertTrue(
+            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+            "Lora + 0 scale should lead to same result as no LoRA",
+        )
+
+    def test_simple_inference_with_transformer_fused(self):
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+
+    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_lora = pipe(**inputs).images
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+        self.assertTrue(
+            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
+            "Fusing should not change the output produced by the still-effective LoRA.",
+        )
+
+    def test_simple_inference_with_transformer_fuse_unfuse(self):
+        components = self.get_dummy_components()
+        transformer_lora_config = self.get_lora_config_for_transformer()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_no_lora = pipe(**inputs).images
+
+        pipe.transformer.add_adapter(transformer_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        pipe.fuse_lora()
+        # Fusing should still keep the LoRA layers
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_fused = pipe(**inputs).images
+        self.assertFalse(
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+        )
+
+        pipe.unfuse_lora()
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_unfused_lora = pipe(**inputs).images
+        self.assertTrue(
+            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
+            "Fusing and unfusing should lead to the same result.",
+        )

     @require_torch_gpu
     def test_sd3_lora(self):
...
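The restored tests above import check_if_lora_correctly_set from the local utils module (the file whose diff follows). The helper itself is not shown in the hunks; roughly, it checks whether PEFT has injected tuner layers into a model (a sketch based on the diffusers test utilities, not part of this diff):

    from peft.tuners.tuners_utils import BaseTunerLayer

    def check_if_lora_correctly_set(model) -> bool:
        # True once a PEFT adapter (e.g. LoRA) has been injected into the model.
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                return True
        return False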
...@@ -19,14 +19,12 @@ from itertools import product ...@@ -19,14 +19,12 @@ from itertools import product
import numpy as np import numpy as np
import torch import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
FlowMatchEulerDiscreteScheduler,
LCMScheduler, LCMScheduler,
SD3Transformer2DModel,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils.import_utils import is_peft_available from diffusers.utils.import_utils import is_peft_available
...@@ -73,47 +71,28 @@ class PeftLoraLoaderMixinTests: ...@@ -73,47 +71,28 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = None scheduler_cls = None
scheduler_kwargs = None scheduler_kwargs = None
has_two_text_encoders = False has_two_text_encoders = False
has_three_text_encoders = False
unet_kwargs = None unet_kwargs = None
transformer_kwargs = None
vae_kwargs = None vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, use_dora=False): def get_dummy_components(self, scheduler_cls=None, use_dora=False):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4 rank = 4
torch.manual_seed(0) torch.manual_seed(0)
if self.unet_kwargs is not None: unet = UNet2DConditionModel(**self.unet_kwargs)
unet = UNet2DConditionModel(**self.unet_kwargs)
else:
transformer = SD3Transformer2DModel(**self.transformer_kwargs)
scheduler = scheduler_cls(**self.scheduler_kwargs) scheduler = scheduler_cls(**self.scheduler_kwargs)
torch.manual_seed(0) torch.manual_seed(0)
vae = AutoencoderKL(**self.vae_kwargs) vae = AutoencoderKL(**self.vae_kwargs)
if not self.has_three_text_encoders: text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
if self.has_two_text_encoders: if self.has_two_text_encoders:
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2") text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
if self.has_three_text_encoders:
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2")
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_lora_config = LoraConfig( text_lora_config = LoraConfig(
r=rank, r=rank,
lora_alpha=rank, lora_alpha=rank,
...@@ -122,7 +101,7 @@ class PeftLoraLoaderMixinTests: ...@@ -122,7 +101,7 @@ class PeftLoraLoaderMixinTests:
use_dora=use_dora, use_dora=use_dora,
) )
denoiser_lora_config = LoraConfig( unet_lora_config = LoraConfig(
r=rank, r=rank,
lora_alpha=rank, lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"], target_modules=["to_q", "to_k", "to_v", "to_out.0"],
...@@ -130,31 +109,18 @@ class PeftLoraLoaderMixinTests: ...@@ -130,31 +109,18 @@ class PeftLoraLoaderMixinTests:
use_dora=use_dora, use_dora=use_dora,
) )
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
if self.unet_kwargs is not None: pipeline_components = {
pipeline_components = { "unet": unet,
"unet": unet, "scheduler": scheduler,
"scheduler": scheduler, "vae": vae,
"vae": vae, "text_encoder": text_encoder,
"text_encoder": text_encoder, "tokenizer": tokenizer,
"tokenizer": tokenizer, "text_encoder_2": text_encoder_2,
"text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2,
"tokenizer_2": tokenizer_2, "image_encoder": None,
"image_encoder": None, "feature_extractor": None,
"feature_extractor": None, }
}
elif self.has_three_text_encoders and self.transformer_kwargs is not None:
pipeline_components = {
"transformer": transformer,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"text_encoder_3": text_encoder_3,
"tokenizer_3": tokenizer_3,
}
else: else:
pipeline_components = { pipeline_components = {
"unet": unet, "unet": unet,
...@@ -167,7 +133,7 @@ class PeftLoraLoaderMixinTests: ...@@ -167,7 +133,7 @@ class PeftLoraLoaderMixinTests:
"image_encoder": None, "image_encoder": None,
} }
return pipeline_components, text_lora_config, denoiser_lora_config return pipeline_components, text_lora_config, unet_lora_config
def get_dummy_inputs(self, with_generator=True): def get_dummy_inputs(self, with_generator=True):
batch_size = 1 batch_size = 1
...@@ -204,12 +170,7 @@ class PeftLoraLoaderMixinTests: ...@@ -204,12 +170,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -217,20 +178,14 @@ class PeftLoraLoaderMixinTests: ...@@ -217,20 +178,14 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs() _, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs).images output_no_lora = pipe(**inputs).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
""" """
Tests a simple inference with lora attached on the text encoder Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -238,13 +193,12 @@ class PeftLoraLoaderMixinTests: ...@@ -238,13 +193,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -260,12 +214,7 @@ class PeftLoraLoaderMixinTests: ...@@ -260,12 +214,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder + scale argument Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -273,13 +222,12 @@ class PeftLoraLoaderMixinTests: ...@@ -273,13 +222,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -290,27 +238,17 @@ class PeftLoraLoaderMixinTests: ...@@ -290,27 +238,17 @@ class PeftLoraLoaderMixinTests:
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
) )
if self.unet_kwargs is not None: output_lora_scale = pipe(
output_lora_scale = pipe( **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} ).images
).images
else:
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
).images
self.assertTrue( self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output", "Lora + scale should change the output",
) )
if self.unet_kwargs is not None: output_lora_0_scale = pipe(
output_lora_0_scale = pipe( **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} ).images
).images
else:
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
).images
self.assertTrue( self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA", "Lora + 0 scale should lead to same result as no LoRA",
...@@ -321,12 +259,7 @@ class PeftLoraLoaderMixinTests: ...@@ -321,12 +259,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -334,13 +267,12 @@ class PeftLoraLoaderMixinTests: ...@@ -334,13 +267,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -350,7 +282,7 @@ class PeftLoraLoaderMixinTests: ...@@ -350,7 +282,7 @@ class PeftLoraLoaderMixinTests:
# Fusing should still keep the LoRA layers # Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
...@@ -365,12 +297,7 @@ class PeftLoraLoaderMixinTests: ...@@ -365,12 +297,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -378,13 +305,12 @@ class PeftLoraLoaderMixinTests: ...@@ -378,13 +305,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -396,7 +322,7 @@ class PeftLoraLoaderMixinTests: ...@@ -396,7 +322,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
) )
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
self.assertFalse( self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2), check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2", "Lora not correctly unloaded in text encoder 2",
...@@ -412,12 +338,7 @@ class PeftLoraLoaderMixinTests: ...@@ -412,12 +338,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA. Tests a simple usecase where users could use saving utilities for LoRA.
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -425,13 +346,12 @@ class PeftLoraLoaderMixinTests: ...@@ -425,13 +346,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -441,7 +361,7 @@ class PeftLoraLoaderMixinTests: ...@@ -441,7 +361,7 @@ class PeftLoraLoaderMixinTests:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
self.pipeline_class.save_lora_weights( self.pipeline_class.save_lora_weights(
...@@ -465,7 +385,7 @@ class PeftLoraLoaderMixinTests: ...@@ -465,7 +385,7 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
...@@ -481,14 +401,9 @@ class PeftLoraLoaderMixinTests: ...@@ -481,14 +401,9 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed with different ranks and some adapters removed
and makes sure it works as expected and makes sure it works as expected
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls) components, _, _ = self.get_dummy_components(scheduler_cls)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig( text_lora_config = LoraConfig(
r=4, r=4,
rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3}, rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
...@@ -503,8 +418,7 @@ class PeftLoraLoaderMixinTests: ...@@ -503,8 +418,7 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
...@@ -516,7 +430,7 @@ class PeftLoraLoaderMixinTests: ...@@ -516,7 +430,7 @@ class PeftLoraLoaderMixinTests:
if "text_model.encoder.layers.4" not in module_name if "text_model.encoder.layers.4" not in module_name
} }
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -548,12 +462,7 @@ class PeftLoraLoaderMixinTests: ...@@ -548,12 +462,7 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
...@@ -561,13 +470,12 @@ class PeftLoraLoaderMixinTests: ...@@ -561,13 +470,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -586,7 +494,7 @@ class PeftLoraLoaderMixinTests: ...@@ -586,7 +494,7 @@ class PeftLoraLoaderMixinTests:
"Lora not correctly set in text encoder", "Lora not correctly set in text encoder",
) )
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
"Lora not correctly set in text encoder 2", "Lora not correctly set in text encoder 2",
...@@ -599,42 +507,27 @@ class PeftLoraLoaderMixinTests: ...@@ -599,42 +507,27 @@ class PeftLoraLoaderMixinTests:
"Loading from saved checkpoints should give same results.", "Loading from saved checkpoints should give same results.",
) )
def test_simple_inference_with_text_denoiser_lora_save_load(self): def test_simple_inference_with_text_unet_lora_save_load(self):
""" """
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
scheduler_classes = ( for scheduler_cls in [DDIMScheduler, LCMScheduler]:
[FlowMatchEulerDiscreteScheduler] components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None: pipe.unet.add_adapter(unet_lora_config)
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config) pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
...@@ -644,36 +537,22 @@ class PeftLoraLoaderMixinTests: ...@@ -644,36 +537,22 @@ class PeftLoraLoaderMixinTests:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
unet_state_dict = get_peft_model_state_dict(pipe.unet)
if self.unet_kwargs is not None: if self.has_two_text_encoders:
denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
else:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
if self.has_two_text_encoders or self.has_three_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
if self.unet_kwargs is not None: self.pipeline_class.save_lora_weights(
self.pipeline_class.save_lora_weights( save_directory=tmpdirname,
save_directory=tmpdirname, text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_lora_layers=text_encoder_state_dict, text_encoder_2_lora_layers=text_encoder_2_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict, unet_lora_layers=unet_state_dict,
unet_lora_layers=denoiser_state_dict, safe_serialization=False,
safe_serialization=False, )
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
transformer_lora_layers=denoiser_state_dict,
safe_serialization=False,
)
else: else:
self.pipeline_class.save_lora_weights( self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict, text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_layers=denoiser_state_dict, unet_lora_layers=unet_state_dict,
safe_serialization=False, safe_serialization=False,
) )
...@@ -684,10 +563,9 @@ class PeftLoraLoaderMixinTests: ...@@ -684,10 +563,9 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders:
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
@@ -697,37 +575,27 @@ class PeftLoraLoaderMixinTests:
                 "Loading from saved checkpoints should give same results.",
             )

-    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+    def test_simple_inference_with_text_unet_lora_and_scale(self):
         """
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -738,27 +606,17 @@ class PeftLoraLoaderMixinTests:
                 not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
             )

-            if self.unet_kwargs is not None:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
-                ).images
-            else:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
-                ).images
+            output_lora_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+            ).images
             self.assertTrue(
                 not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                 "Lora + scale should change the output",
             )

-            if self.unet_kwargs is not None:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
-                ).images
-            else:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
-                ).images
+            output_lora_0_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+            ).images
             self.assertTrue(
                 np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                 "Lora + 0 scale should lead to same result as no LoRA",
@@ -769,38 +627,28 @@ class PeftLoraLoaderMixinTests:
                 "The scaling parameter has not been correctly restored!",
             )

-    def test_simple_inference_with_text_lora_denoiser_fused(self):
+    def test_simple_inference_with_text_lora_unet_fused(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -809,10 +657,9 @@ class PeftLoraLoaderMixinTests:
             pipe.fuse_lora()
             # Fusing should still keep the LoRA layers
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )
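`fuse_lora()` folds the adapter deltas into the base weights so inference skips the extra LoRA matmuls; as the assertions note, the LoRA modules stay attached so the operation can be undone. A sketch, reusing the `pipe` from the previous example:

```python
pipe.fuse_lora()    # fold W + scale * (B @ A) into the base weights
image = pipe("a photo of a corgi").images[0]
pipe.unfuse_lora()  # restore the original, unfused base weights
```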
@@ -822,37 +669,27 @@ class PeftLoraLoaderMixinTests:
                 np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
             )

-    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
+    def test_simple_inference_with_text_unet_lora_unloaded(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -863,12 +700,9 @@ class PeftLoraLoaderMixinTests:
             self.assertFalse(
                 check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
             )
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertFalse(
-                check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser"
-            )
+            self.assertFalse(check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 self.assertFalse(
                     check_if_lora_correctly_set(pipe.text_encoder_2),
                     "Lora not correctly unloaded in text encoder 2",
@@ -880,34 +714,25 @@ class PeftLoraLoaderMixinTests:
                 "Fused lora should change the output",
             )

-    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+    def test_simple_inference_with_text_unet_lora_unfused(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -922,10 +747,9 @@ class PeftLoraLoaderMixinTests:
             output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             # unloading should remove the LoRA layers
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                 )
@@ -936,18 +760,13 @@ class PeftLoraLoaderMixinTests:
                 "Fused lora should change the output",
             )

-    def test_simple_inference_with_text_denoiser_multi_adapter(self):
+    def test_simple_inference_with_text_unet_multi_adapter(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set them
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -958,20 +777,13 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
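Named adapters let several LoRAs coexist on one pipeline; `set_adapters` then selects which are active and, optionally, their mixing weights. A sketch with illustrative repo ids:

```python
pipe.load_lora_weights("user/lora-style-a", adapter_name="adapter-1")
pipe.load_lora_weights("user/lora-style-b", adapter_name="adapter-2")

pipe.set_adapters("adapter-1")                             # only adapter-1 active
pipe.set_adapters(["adapter-1", "adapter-2"], [0.7, 0.3])  # weighted combination
pipe.disable_lora()                                        # back to the base model
```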
@@ -1014,21 +826,13 @@ class PeftLoraLoaderMixinTests:
                 "output with no lora and output with lora disabled should give same results",
             )

-    def test_simple_inference_with_text_denoiser_block_scale(self):
+    def test_simple_inference_with_text_unet_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         one adapter and set differnt weights for different blocks (i.e. block lora)
         """
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
-            return
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1037,16 +841,12 @@ class PeftLoraLoaderMixinTests:
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -1081,21 +881,13 @@ class PeftLoraLoaderMixinTests:
                 "output with no lora and output with lora disabled should give same results",
             )

-    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+    def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set differnt weights for different blocks (i.e. block lora)
         """
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
-            return
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1106,20 +898,13 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1168,10 +953,8 @@ class PeftLoraLoaderMixinTests:
             with self.assertRaises(ValueError):
                 pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

-    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+    def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
-            return

         def updown_options(blocks_with_tf, layers_per_block, value):
             """
@@ -1236,19 +1019,16 @@ class PeftLoraLoaderMixinTests:
             return opts

-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
+        components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-        if self.unet_kwargs is not None:
-            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-        else:
-            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        pipe.unet.add_adapter(unet_lora_config, "adapter-1")

-        if self.has_two_text_encoders or self.has_three_text_encoders:
+        if self.has_two_text_encoders:
             pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

         for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
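Beyond a single scalar, `set_adapters` also accepts a nested dict that scales individual blocks (the "block lora" these tests enumerate exhaustively). A hedged sketch of the dict shape, inferred from the `text_encoder`/`unet` and `down`/`mid`/`up` keys the tests generate; exact keys depend on the pipeline:

```python
scales = {
    "text_encoder": 0.5,
    "unet": {
        "down": 0.9,  # one value for every down-block
        "mid": 1.0,
        "up": {"block_0": 0.6, "block_1": [0.4, 0.8, 1.0]},  # per-block / per-layer
    },
}
pipe.set_adapters("adapter-1", scales)
```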
@@ -1258,18 +1038,13 @@ class PeftLoraLoaderMixinTests:
             pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

-    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
+    def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set/delete them
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1280,20 +1055,13 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1345,14 +1113,8 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             pipe.set_adapters(["adapter-1", "adapter-2"])
             pipe.delete_adapters(["adapter-1", "adapter-2"])
@@ -1364,18 +1126,13 @@ class PeftLoraLoaderMixinTests:
                 "output with no lora and output with lora disabled should give same results",
             )

-    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
+    def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
         multiple adapters and set them
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1386,20 +1143,13 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
@@ -1452,13 +1202,8 @@ class PeftLoraLoaderMixinTests:
     @skip_mps
     def test_lora_fuse_nan(self):
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1466,23 +1211,16 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

             # corrupt one LoRA weight with `inf` values
             with torch.no_grad():
-                if self.unet_kwargs:
-                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
-                        "adapter-1"
-                    ].weight += float("inf")
-                else:
-                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+                    "inf"
+                )

             # with `safe_fusing=True` we should see an Error
             with self.assertRaises(ValueError):
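`safe_fusing=True` validates the fused weights before committing them, which is why corrupting one LoRA matrix with `inf` must raise. Sketch:

```python
# With safe_fusing=True, non-finite fused weights raise a ValueError instead
# of silently corrupting the model (the default, safe_fusing=False, skips
# the check and lets NaN/inf propagate into generated images).
pipe.fuse_lora(safe_fusing=True)
```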
@@ -1500,32 +1238,21 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

             adapter_names = pipe.get_active_adapters()
             self.assertListEqual(adapter_names, ["adapter-1"])

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             adapter_names = pipe.get_active_adapters()
             self.assertListEqual(adapter_names, ["adapter-2"])
@@ -1538,108 +1265,65 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where we attach multiple adapters and check if the results
         are the expected results
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

             adapter_names = pipe.get_list_adapters()
-            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
-            if self.unet_kwargs is not None:
-                dicts_to_be_checked.update({"unet": ["adapter-1"]})
-            else:
-                dicts_to_be_checked.update({"transformer": ["adapter-1"]})
-            self.assertDictEqual(adapter_names, dicts_to_be_checked)
+            self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             adapter_names = pipe.get_list_adapters()
-            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
-            if self.unet_kwargs is not None:
-                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
-            else:
-                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
-            self.assertDictEqual(adapter_names, dicts_to_be_checked)
+            self.assertDictEqual(
+                adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
+            )

             pipe.set_adapters(["adapter-1", "adapter-2"])
-            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
-            if self.unet_kwargs is not None:
-                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
-            else:
-                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
             self.assertDictEqual(
                 pipe.get_list_adapters(),
-                dicts_to_be_checked,
+                {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
             )

-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
-            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
-            if self.unet_kwargs is not None:
-                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
-            else:
-                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
-            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+            pipe.unet.add_adapter(unet_lora_config, "adapter-3")
+            self.assertDictEqual(
+                pipe.get_list_adapters(),
+                {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
+            )

     @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+    def test_simple_inference_with_text_lora_unet_fused_multi(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

             # Attach a second adapter
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
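The two introspection helpers asserted above differ in scope: `get_active_adapters()` lists the adapters currently in use, while `get_list_adapters()` maps every LoRA-bearing component to all adapters attached to it. The shapes the restored assertions expect:

```python
pipe.get_active_adapters()
# ['adapter-2']

pipe.get_list_adapters()
# {'text_encoder': ['adapter-1', 'adapter-2'], 'unet': ['adapter-1', 'adapter-2']}
```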
@@ -1675,35 +1359,23 @@ class PeftLoraLoaderMixinTests:
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
-                scheduler_cls, use_dora=True
-            )
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-            shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
-            self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked)
+            self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -1717,34 +1389,25 @@ class PeftLoraLoaderMixinTests:
             )

     @unittest.skip("This is failing for now - need to investigate")
-    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
+    def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
         """
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

             pipe.text_encoder.add_adapter(text_lora_config)
-            if self.unet_kwargs is not None:
-                pipe.unet.add_adapter(denoiser_lora_config)
-            else:
-                pipe.transformer.add_adapter(denoiser_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)

             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -1753,27 +1416,19 @@ class PeftLoraLoaderMixinTests:
             pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
             pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

-            if self.has_two_text_encoders or self.has_three_text_encoders:
+            if self.has_two_text_encoders:
                 pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

             # Just makes sure it works..
             _ = pipe(**inputs, generator=torch.manual_seed(0)).images

     def test_modify_padding_mode(self):
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
-            return
-
         def set_pad_mode(network, mode="circular"):
             for _, module in network.named_modules():
                 if isinstance(module, torch.nn.Conv2d):
                     module.padding_mode = mode

-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler]
-            if self.has_three_text_encoders and self.transformer_kwargs
-            else [DDIMScheduler, LCMScheduler]
-        )
-        for scheduler_cls in scheduler_classes:
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
             components, _, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(torch_device)
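The compile test above is only a smoke check that a LoRA-patched pipeline still runs under `torch.compile`; the incantation from the hunk applies to any pipeline component. Sketch, reusing `pipe` from the earlier examples:

```python
# Compiling components individually keeps the rest of the pipeline eager;
# the first call below is slow because it triggers compilation.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
_ = pipe("warmup prompt").images
```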
...