Unverified commit a0542c19, authored by Sayak Paul, committed by GitHub

[LoRA] Remove legacy LoRA code and related adjustments (#8316)

* remove legacy code from load_attn_procs.

* finish first draft

* fix more.

* fix more

* add test

* add serialization support.

* fix-copies

* require peft backend for lora tests

* style

* fix test

* fix loading.

* empty

* address benjamin's feedback.
parent a8ad6664
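
As context for the diff below: this PR drops the legacy (non-PEFT) LoRA code paths, so attaching a LoRA adapter to a UNet goes through `peft` via `add_adapter`, the same pattern the new model tests at the bottom of this diff use. A minimal sketch (not taken from the PR itself; the tiny UNet configuration is illustrative only and `peft` must be installed):

# Hedged sketch of the PEFT-backed LoRA flow this PR standardizes on.
# The UNet config below is illustrative; any valid UNet2DConditionModel works.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# Same shape of config as the new tests: rank-4 LoRA on the attention projections.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
unet.add_adapter(lora_config)  # provided by diffusers' PEFT integration; requires `peft`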
@@ -111,3 +111,21 @@ jobs:
             -s -v \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/lora/
+          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+            -s -v \
+            --make-reports=tests_models_lora_${{ matrix.config.report }} \
+            tests/models/ -k "lora"
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: |
+          cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+          cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: pr_${{ matrix.config.report }}_test_reports
+          path: reports
\ No newline at end of file
@@ -189,12 +189,17 @@ jobs:
             -s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
             --make-reports=tests_peft_cuda \
             tests/lora/
+          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+            -s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
+            --make-reports=tests_peft_cuda_models_lora \
+            tests/models/
 
       - name: Failure short reports
         if: ${{ failure() }}
         run: |
           cat reports/tests_peft_cuda_stats.txt
           cat reports/tests_peft_cuda_failures_short.txt
+          cat reports/tests_peft_cuda_models_lora_failures_short.txt
 
       - name: Test suite reports artifacts
         if: ${{ always() }}
...
@@ -22,17 +22,14 @@ import torch
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 from huggingface_hub.utils import validate_hf_hub_args
-from packaging import version
 from torch import nn
 
-from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ..models.modeling_utils import load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
             lora_scale=self.lora_scale,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -395,14 +378,11 @@ class LoraLoaderMixin:
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
 
-        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+        if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
+            unet.load_attn_procs(
+                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            )
 
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
-            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
-                network_alphas = {
-                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                }
-        else:
-            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-            # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            if not USE_PEFT_BACKEND:
-                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                logger.warning(warn_message)
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(unet, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
-                )
-
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if network_alphas is not None:
-                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
-                # `convert_unet_state_dict_to_peft` method.
-                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(unet)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
-
-        unet.load_attn_procs(
-            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
-        )
 
     @classmethod
     def load_lora_into_text_encoder(
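
The comment retained above describes the post-#2918 serialization format, in which keys are namespaced by component ("unet.", "text_encoder."). A small illustrative helper (not part of the library, names are my own) showing how such a combined state dict splits by prefix, mirroring the `only_text_encoder` check added in this hunk:

# Illustrative helper (not library code): split a new-format LoRA state dict by component prefix.
# Keys in the post-#2918 format look like "unet.<module>.lora_A.weight" or "text_encoder.<module>...".
from typing import Dict

import torch


def split_lora_state_dict(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    """Return the entries under `prefix` with the prefix stripped, e.g. prefix="unet"."""
    return {k[len(prefix) + 1 :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}


# Usage sketch, assuming `state_dict` was obtained e.g. via LoraLoaderMixin.lora_state_dict:
# only_text_encoder = all(k.startswith("text_encoder") for k in state_dict)  # same check as above
# unet_sd = split_lora_state_dict(state_dict, "unet")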
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
         text_encoder,
         prefix=None,
         lora_scale=1.0,
-        low_cpu_mem_usage=None,
         adapter_name=None,
         _pipeline=None,
     ):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:
         from peft import LoraConfig
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
             # Unsafe code />
 
     @classmethod
-    def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -640,19 +534,12 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         keys = list(state_dict.keys())
 
         transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
         >>> ...
         ```
        """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
         if not USE_PEFT_BACKEND:
-            if version.parse(__version__) > version.parse("0.23"):
-                logger.warning(
-                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
-                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
-                )
-
-            for _, module in unet.named_modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
-        else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+            raise ValueError("PEFT backend is required for this method.")
+
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.unload_lora()
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
...
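
With this change, `unload_lora_weights` is PEFT-only and simply delegates to `unet.unload_lora()` (whose new implementation lives in the collapsed file below). A hedged usage sketch at the pipeline level; the model and LoRA repository IDs are placeholders, not references from this PR:

# Hedged usage sketch: loading and unloading LoRA weights at the pipeline level.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "some-org/some-sd-checkpoint", torch_dtype=torch.float16  # placeholder model ID
)
pipe.load_lora_weights("some-user/some-lora")  # placeholder LoRA repo; requires `peft` installed
# ... run inference with the adapter active ...
pipe.unload_lora_weights()  # now raises without the PEFT backend; internally calls unet.unload_lora()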
This diff is collapsed.
@@ -903,17 +903,6 @@ class UNet2DConditionModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def unload_lora(self):
-        """Unloads LoRA weights."""
-        deprecate(
-            "unload_lora",
-            "0.28.0",
-            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
-        )
-        for module in self.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
-
     def get_time_embed(
         self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
     ) -> Optional[torch.Tensor]:
...
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, logging
+from ...utils import BaseOutput, logging
 from ..activations import get_activation
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
 
@@ -546,18 +546,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
-    def unload_lora(self):
-        """Unloads LoRA weights."""
-        deprecate(
-            "unload_lora",
-            "0.28.0",
-            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
-        )
-        for module in self.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
-
     def forward(
         self,
         sample: torch.Tensor,
...
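
With the deprecated model-level `unload_lora()` shim removed from the 2D and 3D UNets, temporarily toggling adapters goes through the PEFT integration (`disable_adapters`/`enable_adapters`), as the old deprecation message already suggested, while fully stripping LoRA layers goes through the loader-side `unload_lora()` exercised by the new tests below (its definition is presumably in the collapsed loaders diff). A short sketch, continuing from the adapter-equipped `unet` in the first example above:

# Continuing from the earlier sketch: `unet` is a UNet2DConditionModel with a LoRA adapter attached.
unet.disable_adapters()  # bypass the LoRA layers without removing them (PEFT integration)
unet.enable_adapters()   # re-activate them
unet.unload_lora()       # strip the LoRA layers entirely; the PEFT-based replacement for the removed shim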
@@ -37,7 +37,9 @@ from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
+    is_peft_available,
     load_hf_numpy,
+    require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
     require_torch_accelerator_with_training,
@@ -51,11 +53,38 @@ from diffusers.utils.testing_utils import (
 from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
 
 
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 logger = logging.get_logger(__name__)
 
 enable_full_determinism()
 
 
+def get_unet_lora_config():
+    rank = 4
+    unet_lora_config = LoraConfig(
+        r=rank,
+        lora_alpha=rank,
+        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+        init_lora_weights=False,
+        use_dora=False,
+    )
+    return unet_lora_config
+
+
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 def create_ip_adapter_state_dict(model):
     # "ip_adapter" (cross-attention weights)
     ip_cross_attn_state_dict = {}
@@ -1005,6 +1034,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
 
+    @require_peft_backend
+    def test_lora(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        # forward pass without LoRA
+        with torch.no_grad():
+            non_lora_sample = model(**inputs_dict).sample
+
+        unet_lora_config = get_unet_lora_config()
+        model.add_adapter(unet_lora_config)
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        # forward pass with LoRA
+        with torch.no_grad():
+            lora_sample = model(**inputs_dict).sample
+
+        assert not torch.allclose(
+            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+
+    @require_peft_backend
+    def test_lora_serialization(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        # forward pass without LoRA
+        with torch.no_grad():
+            non_lora_sample = model(**inputs_dict).sample
+
+        unet_lora_config = get_unet_lora_config()
+        model.add_adapter(unet_lora_config)
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        # forward pass with LoRA
+        with torch.no_grad():
+            lora_sample_1 = model(**inputs_dict).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            model.unload_lora()
+            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        with torch.no_grad():
+            lora_sample_2 = model(**inputs_dict).sample
+
+        assert not torch.allclose(
+            non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+        assert torch.allclose(
+            lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
+        ), "Loading from a saved checkpoint should produce identical results."
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
...