Unverified Commit 2523390c authored by Benjamin Bossan, committed by GitHub

FIX Setting device for DoRA parameters (#7655)

Fix a bug that causes the call to set_lora_device to ignore the DoRA
parameters.
parent 279de3c3
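
A note on why the re-assignment below is needed: nn.Module.to() moves a module's parameters in place, while .to() on a bare nn.Parameter returns a new tensor and leaves the stored one untouched. DoRA's lora_magnitude_vector is a parameter, not a module (see the comment in the diff), so the moved tensor has to be written back explicitly. A minimal sketch of the difference, assuming a CUDA device is available; the ParameterDict here only mimics how the magnitude vectors are stored and is not the actual PEFT implementation:

import torch
import torch.nn as nn

# Modules: .to() moves the contained parameters in place.
linear = nn.Linear(4, 4)
linear.to("cuda")
assert linear.weight.device.type == "cuda"

# Bare parameters: .to() returns a new tensor; the stored entry stays on CPU.
magnitudes = nn.ParameterDict({"adapter-1": nn.Parameter(torch.ones(4))})
magnitudes["adapter-1"].to("cuda")  # result is discarded, dict entry unchanged
assert magnitudes["adapter-1"].device.type == "cpu"

# Hence the explicit re-assignment used in the fix below.
magnitudes["adapter-1"] = magnitudes["adapter-1"].to("cuda")
assert magnitudes["adapter-1"].device.type == "cuda"
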
@@ -1267,6 +1267,10 @@ class LoraLoaderMixin:
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
                     unet_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
+                        adapter_name
+                    ].to(device)
 
         # Handle the text encoder
         modules_to_process = []
@@ -1283,6 +1287,10 @@ class LoraLoaderMixin:
                 for adapter_name in adapter_names:
                     text_encoder_module.lora_A[adapter_name].to(device)
                     text_encoder_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    text_encoder_module.lora_magnitude_vector[
+                        adapter_name
+                    ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
 
 
 class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
@@ -150,6 +150,54 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
             if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
                 self.assertTrue(m.weight.device != torch.device("cpu"))
 
+    @require_torch_gpu
+    def test_integration_move_lora_dora_cpu(self):
+        from peft import LoraConfig
+
+        path = "runwayml/stable-diffusion-v1-5"
+        unet_lora_config = LoraConfig(
+            init_lora_weights="gaussian",
+            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+            use_dora=True,
+        )
+        text_lora_config = LoraConfig(
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            use_dora=True,
+        )
+
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.text_encoder),
+            "Lora not correctly set in text encoder",
+        )
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        for name, param in pipe.unet.named_parameters():
+            if "lora_" in name:
+                self.assertEqual(param.device, torch.device("cpu"))
+
+        for name, param in pipe.text_encoder.named_parameters():
+            if "lora_" in name:
+                self.assertEqual(param.device, torch.device("cpu"))
+
+        pipe.set_lora_device(["adapter-1"], torch_device)
+
+        for name, param in pipe.unet.named_parameters():
+            if "lora_" in name:
+                self.assertNotEqual(param.device, torch.device("cpu"))
+
+        for name, param in pipe.text_encoder.named_parameters():
+            if "lora_" in name:
+                self.assertNotEqual(param.device, torch.device("cpu"))
+
     @slow
     @require_torch_gpu