Unverified commit 3b079ec3, authored by Benjamin Bossan, committed by GitHub

ENH: Improve speed of function expanding LoRA scales (#11834)

* ENH Improve speed of expanding LoRA scales

Resolves #11816

The following call proved to be a bottleneck when setting a lot of LoRA
adapters in diffusers:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/peft.py#L482

This is because we would repeatedly call unet.state_dict(), even though in the
standard case it is not necessary:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/unet_loader_utils.py#L55
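
For illustration only (not part of this commit), here is a minimal sketch of the cost pattern: a dummy nn.Sequential stands in for the UNet and the loop stands in for the per-adapter expansion, so the absolute timings are meaningless and only the eager-vs-deferred comparison matters.

```python
# Hypothetical micro-benchmark; the dummy model and the flag below are
# illustrative stand-ins, not diffusers code.
import time

from torch import nn

model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(200)])  # stand-in for a large UNet
num_adapters = 50

# Eager pattern (old behaviour): materialize the state dict once per adapter.
start = time.perf_counter()
for _ in range(num_adapters):
    state_dict = model.state_dict()
eager = time.perf_counter() - start

# Deferred pattern (new behaviour): fetch the state dict only when the
# per-layer validation needs it; with plain float scales it never does.
needs_state_dict = False
start = time.perf_counter()
for _ in range(num_adapters):
    if needs_state_dict:
        state_dict = model.state_dict()
deferred = time.perf_counter() - start

print(f"eager: {eager:.4f}s, deferred: {deferred:.4f}s")
```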

This PR fixes the issue by deferring the call so that it only runs when it is
actually needed, not earlier.
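
In essence, the change is the following deferral pattern; the function and helper names below (expand_scales_eager, expand_scales_deferred, _expand) are simplified, hypothetical stand-ins for the real helpers rather than diffusers APIs.

```python
from typing import Dict, Union

from torch import nn


def _expand(scales: Dict, state_dict: Dict) -> Dict:
    # Stand-in for the real per-layer expansion and validation logic.
    return {name: value for name, value in scales.items() if any(name in key for key in state_dict)}


# Before: the call site passed unet.state_dict(), so the cost was paid up front
# for every adapter, even when scales is a plain float and the dict goes unused.
def expand_scales_eager(scales: Union[float, Dict], state_dict: Dict):
    if not isinstance(scales, dict):  # common case: nothing to expand or validate
        return scales
    return _expand(scales, state_dict)


# After: the caller hands over the model and state_dict() runs only when needed.
def expand_scales_deferred(scales: Union[float, Dict], model: nn.Module):
    if not isinstance(scales, dict):  # common case: state_dict() is never called
        return scales
    return _expand(scales, model.state_dict())
```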

* Small fix

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bc34fa83
@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet.state_dict(),
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
     blocks_with_transformer: Dict[str, int],
     transformer_per_block: Dict[str, int],
-    state_dict: None,
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
         del scales[updown]
 
+    state_dict = model.state_dict()
     for layer in scales.keys():
         if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
             raise ValueError(
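
For context, a hypothetical usage sketch of the code path that benefits (the checkpoint and LoRA identifiers are placeholders and assume a LoRA-capable pipeline such as SDXL): set_adapters() expands one scale entry per adapter, so the old code ran unet.state_dict() once for each of them.

```python
from diffusers import DiffusionPipeline

# Placeholder identifiers; substitute a real base checkpoint and LoRA repositories.
pipe = DiffusionPipeline.from_pretrained("some/base-checkpoint")

adapter_names = [f"style_{i}" for i in range(20)]
for name in adapter_names:
    pipe.load_lora_weights("some/lora-repo", weight_name=f"{name}.safetensors", adapter_name=name)

# Each plain float weight previously still triggered a fresh unet.state_dict()
# call during scale expansion; with this change the call is deferred and skipped.
pipe.set_adapters(adapter_names, adapter_weights=[0.5] * len(adapter_names))
```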