Unverified Commit 142f353e authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Fix lora device test (#7738)



* fix lora device test

* fix more.

* fix more/

* quality

* empty

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent b833d0fc
......@@ -1268,9 +1268,10 @@ class LoraLoaderMixin:
unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
adapter_name
].to(device)
if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
adapter_name
].to(device)
# Handle the text encoder
modules_to_process = []
......@@ -1288,9 +1289,13 @@ class LoraLoaderMixin:
text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
text_encoder_module.lora_magnitude_vector[
adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
if (
hasattr(text_encoder, "lora_magnitude_vector")
and text_encoder_module.lora_magnitude_vector is not None
):
text_encoder_module.lora_magnitude_vector[
adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment