Unverified commit 02247d9c, authored by Sourab Mangrulkar and committed by GitHub

PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters (#5147)

* more fixes

* up

* up

* style

* add in setup

* oops

* more changes

* v1 rzfactor CI

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* few todos

* protect torch import

* style

* fix fuse text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* replace with `recurse_replace_peft_layers`

* keep old modules for BC

* adjustments on `adjust_lora_scale_text_encoder`

* nit

* move tests

* add conversion utils

* remove unneeded methods

* use class method instead

* oops

* use `base_version`

* fix examples

* fix CI

* fix weird error with python 3.8

* fix

* better fix

* style

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add comment

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* conv2d support for recurse remove

* added docstrings

* more docstring

* add deprecate

* revert

* try to fix merge conflicts

* peft integration features for text encoder

1. support multiple rank/alpha values
2. support multiple active adapters
3. support disabling and enabling adapters

* fix bug

* fix code quality

* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix bugs

* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* address comments
Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>

* fix code quality

* address comments

* address comments

* Apply suggestions from code review

* find and replace

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent 940f9410
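
In practice, the changes below let a pipeline load several LoRAs for the text encoder under distinct adapter names (each with its own rank/alpha), weight them against each other, and toggle them on or off. A minimal usage sketch, assuming the PEFT backend is enabled and `peft` is installed; the LoRA repository and adapter names are hypothetical placeholders:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Each LoRA checkpoint is tracked under its own adapter name, even if the two
# checkpoints were trained with different ranks or network alphas.
pipe.load_lora_weights("user/lora-style-a", adapter_name="style_a")  # hypothetical repo
pipe.load_lora_weights("user/lora-style-b", adapter_name="style_b")  # hypothetical repo

# Activate both adapters on the text encoder with per-adapter weights ...
pipe.set_adapter_for_text_encoder(["style_a", "style_b"], text_encoder_weights=[1.0, 0.5])

# ... or temporarily disable and re-enable the text encoder LoRA layers.
pipe.disable_lora_for_text_encoder()
pipe.enable_lora_for_text_encoder()
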
@@ -35,18 +35,23 @@ from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
     is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
 )
 from .utils.import_utils import BACKENDS_MAPPING

 if is_transformers_available():
-    from transformers import CLIPTextModel, CLIPTextModelWithProjection
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
     num_fused_loras = 0
     use_peft_backend = USE_PEFT_BACKEND

-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -1120,6 +1127,9 @@ class LoraLoaderMixin:
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -1143,6 +1153,7 @@ class LoraLoaderMixin:
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -1500,6 +1511,7 @@ class LoraLoaderMixin:
         prefix=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
+        adapter_name=None,
         _pipeline=None,
     ):
         """
@@ -1523,6 +1535,9 @@ class LoraLoaderMixin:
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
@@ -1584,19 +1599,22 @@ class LoraLoaderMixin:
         if cls.use_peft_backend:
             from peft import LoraConfig

-            lora_rank = list(rank.values())[0]
-            # By definition, the scale should be alpha divided by rank.
-            # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
-            alpha = lora_scale * lora_rank
-
-            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-            if patch_mlp:
-                target_modules += ["fc1", "fc2"]
-
-            # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
-            lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
-
-            text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            # inject LoRA layers and load the state dict
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)

         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
@@ -2178,6 +2196,81 @@ class LoraLoaderMixin:
         self.num_fused_loras -= 1

+    def set_adapter_for_text_encoder(
+        self,
+        adapter_names: Union[List[str], str],
+        text_encoder: Optional[PreTrainedModel] = None,
+        text_encoder_weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the text encoder.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to set the adapter layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+            text_encoder_weights (`List[float]`, *optional*):
+                The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the
+                adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError(
+                "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+            )
+        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+    def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Disables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=False)
+
+    def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Enables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        # operate on the resolved `text_encoder` so an explicitly passed module is honored
+        set_adapter_layers(text_encoder, enabled=True)
+

 class FromSingleFileMixin:
     """
...
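
# Hedged usage sketch of the three methods added above, continuing the pipeline
# sketch from the commit header. The adapter names are hypothetical and must match
# the names given to `load_lora_weights`; the explicit `text_encoder=` argument is
# useful for pipelines with more than one encoder (e.g. SDXL's `pipe.text_encoder_2`).
pipe.set_adapter_for_text_encoder(["style_a", "style_b"], text_encoder_weights=[1.0, 0.5])
pipe.set_adapter_for_text_encoder("style_a", text_encoder=pipe.text_encoder_2)
pipe.disable_lora_for_text_encoder()  # temporarily turn LoRA off for pipe.text_encoder
pipe.enable_lora_for_text_encoder()   # and back on again
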
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
-from ..utils import logging
+from ..utils import logging, scale_lora_layers


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -27,11 +27,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
     if use_peft_backend:
-        from peft.tuners.lora import LoraLayer
-
-        for module in text_encoder.modules():
-            if isinstance(module, LoraLayer):
-                module.scaling[module.active_adapter] = lora_scale
+        scale_lora_layers(text_encoder, weight=lora_scale)
     else:
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
...
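
# Hedged illustration (not library code) of what the new code path does: with the PEFT
# backend, `adjust_lora_scale_text_encoder` delegates to `scale_lora_layers`, which
# multiplies each layer's LoRA contribution by `lora_scale` on top of the usual
# alpha / r scaling of a LoRA linear layer.
import torch

def lora_linear_forward(x, base_weight, lora_A, lora_B, alpha, r, lora_scale=1.0):
    # base projection plus the low-rank update, scaled by (alpha / r) and `lora_scale`
    return x @ base_weight.T + lora_scale * (alpha / r) * (x @ lora_A.T @ lora_B.T)
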
@@ -84,7 +84,14 @@ from .import_utils import (
 from .loading_utils import load_image
 from .logging import get_logger
 from .outputs import BaseOutput
-from .peft_utils import recurse_remove_peft_layers
+from .peft_utils import (
+    get_adapter_name,
+    get_peft_kwargs,
+    recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
+)
 from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
 from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
...
@@ -14,6 +14,8 @@
 """
 PEFT utilities: Utilities related to peft library
 """
+import collections
+
 from .import_utils import is_torch_available
@@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model):
         torch.cuda.empty_cache()

     return model
+
+
+def scale_lora_layers(model, weight):
+    """
+    Adjust the weightage given to the LoRA layers of the model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to scale.
+        weight (`float`):
+            The weight to be given to the LoRA layers.
+    """
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            module.scale_layer(weight)
+
+
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
+    rank_pattern = {}
+    alpha_pattern = {}
+    r = lora_alpha = list(rank_dict.values())[0]
+    if len(set(rank_dict.values())) > 1:
+        # get the rank occurring the most number of times
+        r = collections.Counter(rank_dict.values()).most_common()[0][0]
+
+        # for modules with rank different from the most occurring rank, add them to the `rank_pattern`
+        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
+        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
+
+    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
+        # get the alpha occurring the most number of times
+        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+
+        # for modules with alpha different from the most occurring alpha, add them to the `alpha_pattern`
+        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
+        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+
+    # layer names without the Diffusers-specific suffixes
+    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+
+    lora_config_kwargs = {
+        "r": r,
+        "lora_alpha": lora_alpha,
+        "rank_pattern": rank_pattern,
+        "alpha_pattern": alpha_pattern,
+        "target_modules": target_modules,
+    }
+    return lora_config_kwargs
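
# Hedged worked example of `get_peft_kwargs` (hypothetical key names, following the
# string handling in the function above): two modules use rank 8 and one uses rank 4,
# so r=8 becomes the default and the outlier lands in `rank_pattern`.
from diffusers.utils import get_peft_kwargs

rank_dict = {
    "layers.0.q_proj.lora_B.weight": 8,
    "layers.0.k_proj.lora_B.weight": 8,
    "layers.0.v_proj.lora_B.weight": 4,
}
peft_state_dict = {k: None for k in rank_dict}  # only the key names matter here

kwargs = get_peft_kwargs(rank_dict, None, peft_state_dict)
# kwargs["r"] == 8, kwargs["lora_alpha"] == 8
# kwargs["rank_pattern"] == {"layers.0.v_proj": 4}
# kwargs["target_modules"] contains "layers.0.q_proj", "layers.0.k_proj", "layers.0.v_proj" (set order)
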
+
+
+def get_adapter_name(model):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return f"default_{len(module.r)}"
+    return "default_0"
+
+
+def set_adapter_layers(model, enabled=True):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            # The recent version of PEFT needs to call `enable_adapters` instead;
+            # honor the `enabled` flag in both branches
+            if hasattr(module, "enable_adapters"):
+                module.enable_adapters(enabled=enabled)
+            else:
+                module.disable_adapters = not enabled
+
+
+def set_weights_and_activate_adapters(model, adapter_names, weights):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    # iterate over each adapter, make it active and set the corresponding scaling weight
+    for adapter_name, weight in zip(adapter_names, weights):
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                # For backward compatibility with previous PEFT versions
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                else:
+                    module.active_adapter = adapter_name
+                module.scale_layer(weight)
+
+    # set multiple active adapters
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            # For backward compatibility with previous PEFT versions
+            if hasattr(module, "set_adapter"):
+                module.set_adapter(adapter_names)
+            else:
+                module.active_adapter = adapter_names