Unverified Commit 25279175 authored by Benjamin Bossan, committed by GitHub

FIX set_lora_device when target layers differ (#11844)



* FIX set_lora_device when target layers differ

Resolves #11833

Fixes a bug that occurs after calling set_lora_device when multiple LoRA
adapters are loaded that target different layers.
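
A minimal sketch of the failing pattern (paths and adapter names are placeholders, not taken from the original report):

```python
# Hypothetical repro: two LoRAs whose target modules only partly overlap.
pipe.load_lora_weights(path_1, adapter_name="adapter-1")  # e.g. targets to_k, to_v
pipe.load_lora_weights(path_2, adapter_name="adapter-2")  # e.g. targets to_k, to_q

# Before this fix, set_lora_device iterated over every BaseTunerLayer and indexed
# module.lora_A[adapter_name] unconditionally, so a layer targeted by only one of
# the adapters would raise a KeyError here:
pipe.set_lora_device(["adapter-1"], "cpu")
```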

Note: Technically, the accompanying test does not require a GPU, because
the bug is triggered even if the parameters are already on the target
device, i.e. loading on CPU and then "moving" the adapter to CPU is
sufficient to trigger it. However, such a no-op device move may be
optimized away in the future, so I decided to test with a GPU.

* Update docstring to warn about device mismatch

* Extend docstring with an example

* Fix docstring

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent e6639fef
@@ -934,6 +934,27 @@ class LoraBaseMixin:
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
         you want to load multiple adapters and free some GPU memory.
+
+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -949,6 +970,10 @@ class LoraBaseMixin:
             for module in model.modules():
                 if isinstance(module, BaseTunerLayer):
                     for adapter_name in adapter_names:
+                        if adapter_name not in module.lora_A:
+                            # it is sufficient to check lora_A
+                            continue
+
                         module.lora_A[adapter_name].to(device)
                         module.lora_B[adapter_name].to(device)
                         # this is a param, not a module, so device placement is not in-place -> re-assign
@@ -120,7 +120,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )

         # We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )

         for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
if "lora_" in name: if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu")) self.assertNotEqual(param.device, torch.device("cpu"))
+
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
     @slow
     @nightly