Unverified Commit 142f353e authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Fix lora device test (#7738)



* fix lora device test

* fix more.

* fix more/

* quality

* empty

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent b833d0fc
...@@ -1268,6 +1268,7 @@ class LoraLoaderMixin: ...@@ -1268,6 +1268,7 @@ class LoraLoaderMixin:
unet_module.lora_A[adapter_name].to(device) unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device) unet_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign # this is a param, not a module, so device placement is not in-place -> re-assign
if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[ unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
adapter_name adapter_name
].to(device) ].to(device)
...@@ -1288,6 +1289,10 @@ class LoraLoaderMixin: ...@@ -1288,6 +1289,10 @@ class LoraLoaderMixin:
text_encoder_module.lora_A[adapter_name].to(device) text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device) text_encoder_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign # this is a param, not a module, so device placement is not in-place -> re-assign
if (
hasattr(text_encoder, "lora_magnitude_vector")
and text_encoder_module.lora_magnitude_vector is not None
):
text_encoder_module.lora_magnitude_vector[ text_encoder_module.lora_magnitude_vector[
adapter_name adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device) ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment