"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4f8853e48184b5610b08b5fe8545b16a693066e1"
Unverified Commit 142f353e authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Fix lora device test (#7738)



* fix lora device test

* fix more.

* fix more/

* quality

* empty

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent b833d0fc
...@@ -1268,9 +1268,10 @@ class LoraLoaderMixin: ...@@ -1268,9 +1268,10 @@ class LoraLoaderMixin:
unet_module.lora_A[adapter_name].to(device) unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device) unet_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign # this is a param, not a module, so device placement is not in-place -> re-assign
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[ if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
adapter_name unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
].to(device) adapter_name
].to(device)
# Handle the text encoder # Handle the text encoder
modules_to_process = [] modules_to_process = []
...@@ -1288,9 +1289,13 @@ class LoraLoaderMixin: ...@@ -1288,9 +1289,13 @@ class LoraLoaderMixin:
text_encoder_module.lora_A[adapter_name].to(device) text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device) text_encoder_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign # this is a param, not a module, so device placement is not in-place -> re-assign
text_encoder_module.lora_magnitude_vector[ if (
adapter_name hasattr(text_encoder, "lora_magnitude_vector")
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device) and text_encoder_module.lora_magnitude_vector is not None
):
text_encoder_module.lora_magnitude_vector[
adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment