"vscode:/vscode.git/clone" did not exist on "86c81b4e927d94ed2dba76fc04e2088c6931e6b5"
Unverified Commit 92542719 authored by Sayak Paul, committed by GitHub

[docs] minor cleanups in the lora docs. (#11770)



* minor cleanups in the lora docs.

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* format docs

* fix copies

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 67603002
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 </Tip>
 
+## LoraBaseMixin
+
+[[autodoc]] loaders.lora_base.LoraBaseMixin
+
 ## StableDiffusionLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
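For orientation, here is a minimal sketch of the loading workflow these mixins document (the model and LoRA repo names are the ones used in the examples later in this diff):

```py
import torch
from diffusers import AutoPipelineForText2Image

# Any pipeline built on LoraBaseMixin / a *LoraLoaderMixin exposes load_lora_weights().
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
```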
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 [[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
 
-## LoraBaseMixin
-
-[[autodoc]] loaders.lora_base.LoraBaseMixin
-
 ## WanLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
\ No newline at end of file
@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading if the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating whether `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
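The two returned flags let a caller restore whatever offloading scheme was active before mutating the pipeline; a minimal sketch of that pattern, assuming `pipeline` is the pipeline constructed earlier:

```py
# Temporarily lift CPU offloading so module weights can be patched in place.
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline=pipeline)

# ... load LoRA weights into the now-materialized modules ...

# Restore the offloading scheme that was active before.
if is_model_cpu_offload:
    pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    pipeline.enable_sequential_cpu_offload()
```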
@@ -453,6 +464,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
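The three properties added above are easy to exercise together; a short sketch, continuing from a pipeline with one adapter loaded:

```py
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora()

print(pipeline.lora_scale)       # 1.0 unless `_lora_scale` was set by the pipeline
print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # the set of merged adapter names
```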
@@ -464,33 +493,6 @@ class LoraBaseMixin:
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
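For contrast with `load_lora_weights()`, a minimal sketch of the unloading path on a concrete pipeline:

```py
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
image = pipeline("pixel art of a corgi").images[0]

# Drop all LoRA parameters and fall back to the base weights.
pipeline.unload_lora_weights()
```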
@@ -661,19 +663,37 @@ class LoraBaseMixin:
                     self._merged_adapters = self._merged_adapters - {adapter}
                     module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        """
+        Set the currently active adapters for use in the pipeline.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            adapter_weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
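The `isinstance(adapter_weights, dict)` branch shown above accepts per-component scales keyed by entries of `_lora_loadable_modules`; a hedged sketch of that call form:

```py
# Scale the "pixel" adapter differently in the UNet and the text encoder.
pipeline.set_adapters(
    "pixel",
    adapter_weights={"unet": 0.8, "text_encoder": 0.5},
)
```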
@@ -743,6 +763,24 @@ class LoraBaseMixin:
                 set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
+        """
+        Disables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -755,6 +793,24 @@ class LoraBaseMixin:
             disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
+        """
+        Enables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -768,10 +824,26 @@ class LoraBaseMixin:
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +944,24 @@ class LoraBaseMixin:
                     adapter_name
                 ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Enables hotswapping adapters without triggering recompilation of a model, even when the ranks of the
+        loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle a model that is already compiled. The options are:
+                    - "error" (default): raise an error
+                    - "warn": issue a warning
+                    - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
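A sketch of the workflow `enable_lora_hotswap` (added above) is meant for; the local file names are hypothetical, and the `hotswap=True` argument to `load_lora_weights()` is assumed from the PEFT hotswap integration:

```py
import torch

# Must be called before the first adapter is loaded and before compiling.
pipeline.enable_lora_hotswap(target_rank=64)
pipeline.load_lora_weights("lora_a.safetensors", adapter_name="default_0")  # hypothetical file
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

image = pipeline("a corgi").images[0]

# Swap a second adapter into the same slot without triggering recompilation.
pipeline.load_lora_weights("lora_b.safetensors", adapter_name="default_0", hotswap=True)  # hypothetical file
image = pipeline("a corgi").images[0]
```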
@@ -887,6 +977,7 @@ class LoraBaseMixin:
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
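`write_lora_layers` is the shared backend behind the public `save_lora_weights()` methods; a hedged sketch of the public call, where the layer state dicts are hypothetical variables:

```py
pipeline.save_lora_weights(
    save_directory="my-sdxl-lora",
    unet_lora_layers=unet_lora_layers,                  # hypothetical state dict
    text_encoder_lora_layers=text_encoder_lora_layers,  # hypothetical state dict
    safe_serialization=True,
)
```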
@@ -927,28 +1018,18 @@ class LoraBaseMixin:
             save_function(state_dict, save_path)
             logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
-
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
-
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                    - "error" (default): raise an error
-                    - "warn": issue a warning
-                    - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
+
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
+
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
@@ -85,17 +85,6 @@ class PeftAdapterMixin:
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     def load_lora_adapter(
@@ -444,7 +433,7 @@ class PeftAdapterMixin:
         weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
     ):
         """
-        Set the currently active adapters for use in the UNet.
+        Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
 
         Args:
             adapter_names (`List[str]` or `str`):
@@ -466,7 +455,7 @@ class PeftAdapterMixin:
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -714,7 +703,7 @@ class PeftAdapterMixin:
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.disable_lora()
+        pipeline.unet.disable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -737,7 +726,7 @@ class PeftAdapterMixin:
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
-        pipeline.enable_lora()
+        pipeline.unet.enable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -764,7 +753,7 @@ class PeftAdapterMixin:
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.delete_adapters("cinematic")
+        pipeline.unet.delete_adapters("cinematic")
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -394,17 +394,6 @@ class UNet2DConditionLoadersMixin:
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     def save_attn_procs(