"...text-generation-inference.git" did not exist on "efb73fcb598fbb93c6cae7d6667a58b373b0de96"
Unverified Commit 107e0216 authored by Sayak Paul, committed by GitHub

[LoRA tests] fix stuff related to assertions arising from the recent changes. (#6448)

* debug

* debug test_with_different_scales_fusion_equivalence

* use the right method.

* place it right.

* let's see.

* let's see again

* alright then.

* add a comment.
parent 6dbef45e
@@ -317,9 +317,9 @@ class LoraLoaderMixinTests(unittest.TestCase):
         text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(
             text_encoder, dtype=torch.float32, rank=self.lora_rank
         )
-        text_encoder_lora_params = set_lora_weights(
-            text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1
-        )
+        text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder)
+        # We call this to ensure that the effects of the in-place `_modify_text_encoder` have been erased.
+        LoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder)
 
         pipeline_components = {
             "unet": unet,
@@ -937,18 +937,17 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
         _, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank)
 
         if modify_text_encoder:
-            text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
+            _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
                 text_encoder, dtype=torch.float32, rank=self.lora_rank
             )
-            text_encoder_lora_params = set_lora_weights(
-                text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1
-            )
-            text_encoder_two_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
+            text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder)
+            StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder)
+
+            _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
                 text_encoder_2, dtype=torch.float32, rank=self.lora_rank
             )
-            text_encoder_two_lora_params = set_lora_weights(
-                text_encoder_lora_state_dict(text_encoder_2), randn_weight=True, var=0.1
-            )
+            text_encoder_two_lora_params = text_encoder_lora_state_dict(text_encoder_2)
+            StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder_2)
         else:
             text_encoder_lora_params = None
             text_encoder_two_lora_params = None
@@ -1446,7 +1445,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             sd_pipe.save_pretrained(tmpdirname)
-            sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname)
+            sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname).to(torch_device)
 
         loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images
         loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1]
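For context, the setup pattern these hunks switch the tests to is: patch LoRA layers into the text encoder in place, read the LoRA parameters back out as a state dict, and then undo the monkey patch so the pipeline under test starts from an unmodified encoder. The last hunk additionally moves the reloaded pipeline to `torch_device` so its outputs are generated on the same device as the original pipeline before the image slices are compared. Below is a minimal, hedged sketch of that pattern; the import paths and the tiny dummy CLIPTextConfig are illustrative assumptions, while the helpers themselves (`_modify_text_encoder`, `text_encoder_lora_state_dict`, `_remove_text_encoder_monkey_patch_classmethod`) are the ones exercised in the diff above.

# Minimal sketch of the setup pattern used above. Import paths and the dummy
# config are assumptions for illustration; the helper calls mirror the diff.
import torch
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import text_encoder_lora_state_dict  # assumed location of this helper

# Tiny text encoder standing in for the test suite's dummy components.
text_encoder = CLIPTextModel(
    CLIPTextConfig(
        vocab_size=1000,
        hidden_size=32,
        intermediate_size=37,
        num_attention_heads=4,
        num_hidden_layers=5,
    )
)

# 1. Monkey-patch LoRA layers into the text encoder in place.
LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=4)

# 2. Capture the LoRA parameters as a plain state dict; the tests now pass this
#    around instead of the return value of `_modify_text_encoder`.
text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder)

# 3. Undo the in-place patch so the pipeline under test starts from a clean encoder.
LoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder)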