Unverified Commit 2bfa55f4 authored by Younes Belkada, committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism system

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora (a usage sketch follows the list)

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs
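
A rough usage sketch of these three features, modeled on the tests added later in this diff (the checkpoint id, adapter names and LoRA configs are illustrative; `add_adapter`, `set_adapters`, `disable_lora` and `get_active_adapters` are the entry points the new tests exercise):

    import torch
    from diffusers import StableDiffusionPipeline
    from peft import LoraConfig

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # different ranks/alphas per adapter are now supported
    unet_config_a = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
    unet_config_b = LoraConfig(r=8, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
    pipe.unet.add_adapter(unet_config_a, "adapter-1")
    pipe.unet.add_adapter(unet_config_b, "adapter-2")

    pipe.set_adapters(["adapter-1", "adapter-2"])  # several active adapters at once
    print(pipe.get_active_adapters())              # ["adapter-1", "adapter-2"]

    pipe.disable_lora()                            # switch every LoRA off without unloading it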

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* style & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statement

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method
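
For reference, a short sketch of how the two new introspection helpers behave, following `test_get_adapters` and `test_get_list_adapters` further down (adapter names are illustrative, `pipe` and the LoRA configs set up as in the earlier sketch):

    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    pipe.unet.add_adapter(unet_lora_config, "adapter-1")
    pipe.unet.add_adapter(unet_lora_config, "adapter-2")

    pipe.get_active_adapters()  # ["adapter-2"] -- only the last added adapter is active until set_adapters() is called
    pipe.get_list_adapters()    # {"text_encoder": ["adapter-1"], "unet": ["adapter-1", "adapter-2"]}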

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
@@ -31,7 +31,14 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -283,7 +290,7 @@ class StableDiffusionXLAdapterPipeline(
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
        else:
@@ -425,7 +432,7 @@ class StableDiffusionXLAdapterPipeline(
                bs_embed * num_images_per_prompt, -1
            )
-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
            unscale_lora_layers(self.text_encoder_2)
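
Note: the same `encode_prompt` change recurs in the next three pipelines below: when the PEFT backend is active, the LoRA scale is applied by walking the text encoder's PEFT layers instead of patching diffusers' own LoRA linear layers. Schematically, `scale_lora_layers` / `unscale_lora_layers` amount to the following (a sketch assuming PEFT's `BaseTunerLayer` interface; the real helpers live in `diffusers.utils`):

    from peft.tuners.tuners_utils import BaseTunerLayer

    def scale_lora_layers(model, weight):
        # temporarily multiply every LoRA delta in `model` by `weight`
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                module.scale_layer(weight)

    def unscale_lora_layers(model):
        # undo the temporary scaling so the next call starts from scale 1.0
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                module.unscale_layer()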
......
@@ -23,7 +23,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
@@ -224,7 +231,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)
@@ -352,7 +359,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
......
@@ -24,7 +24,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
@@ -286,7 +293,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)
@@ -414,7 +421,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
......
@@ -18,7 +18,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -426,7 +426,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)
@@ -554,7 +554,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
......
@@ -31,7 +31,7 @@ from ...models.embeddings import (
)
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import is_torch_version, logging
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import apply_freeu
@@ -1211,6 +1211,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1310,6 +1313,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self)
        if not return_dict:
            return (sample,)
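
Note: wrapping the whole UNet forward in `scale_lora_layers(self, lora_scale)` / `unscale_lora_layers(self)` keeps the usual `cross_attention_kwargs={"scale": ...}` knob working when the LoRA layers are PEFT modules. A call-site sketch (prompt and values illustrative, borrowed from the tests below):

    image = pipe(
        "a red sks dog",
        num_inference_steps=30,
        cross_attention_kwargs={"scale": 0.5},  # forwarded to the UNet and applied via scale_lora_layers
        generator=torch.manual_seed(0),
    ).images[0]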
......
@@ -26,9 +26,11 @@ from .constants import (
    FLAX_WEIGHTS_NAME,
    HF_MODULES_CACHE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    MIN_PEFT_VERSION,
    ONNX_EXTERNAL_WEIGHTS_NAME,
    ONNX_WEIGHTS_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    USE_PEFT_BACKEND,
    WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
@@ -86,6 +88,7 @@ from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
    check_peft_version,
    get_adapter_name,
    get_peft_kwargs,
    recurse_remove_peft_layers,
@@ -95,7 +98,11 @@ from .peft_utils import (
    unscale_lora_layers,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
-from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
from .state_dict_utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    convert_unet_state_dict_to_peft,
)

logger = get_logger(__name__)
......
@@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version

from .import_utils import is_peft_available, is_transformers_available

default_cache_path = HUGGINGFACE_HUB_CACHE

MIN_PEFT_VERSION = "0.5.0"

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
@@ -30,3 +36,16 @@ DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
    version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(MIN_PEFT_VERSION)
_required_transformers_version = is_transformers_available() and version.parse(
    version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
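
Note: per the code above, `MIN_PEFT_VERSION` is "0.5.0" and both checks are strictly greater-than, so in practice peft must be newer than 0.5.0 and transformers newer than 4.33 for the backend to switch on. A quick sketch for checking which path a given environment takes (the constant is re-exported from `diffusers.utils`, as the `utils/__init__.py` hunk above shows):

    from diffusers.utils import USE_PEFT_BACKEND

    if USE_PEFT_BACKEND:
        print("LoRA is handled through peft")
    else:
        print("falling back to the legacy diffusers LoRA layers")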
@@ -15,8 +15,11 @@
PEFT utilities: Utilities related to peft library
"""
import collections
import importlib

-from .import_utils import is_torch_available
from packaging import version

from .import_utils import is_peft_available, is_torch_available


def recurse_remove_peft_layers(model):
@@ -53,7 +56,6 @@ def recurse_remove_peft_layers(model):
                module.padding,
                module.dilation,
                module.groups,
-                module.bias,
            ).to(module.weight.device)
            new_module.weight = module.weight
@@ -106,10 +108,11 @@ def unscale_lora_layers(model):
            module.unscale_layer()

-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
    rank_pattern = {}
    alpha_pattern = {}
    r = lora_alpha = list(rank_dict.values())[0]
    if len(set(rank_dict.values())) > 1:
        # get the rank occuring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -118,13 +121,22 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

-    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
-        # get the alpha occuring the most number of times
-        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
-        # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
-        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
-        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
    if network_alpha_dict is not None:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occuring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
            if is_unet:
                alpha_pattern = {
                    ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
                    for k, v in alpha_pattern.items()
                }
            else:
                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
        else:
            lora_alpha = set(network_alpha_dict.values()).pop()

    # layer names without the Diffusers specific
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
@@ -155,9 +167,9 @@ def set_adapter_layers(model, enabled=True):
        if isinstance(module, BaseTunerLayer):
            # The recent version of PEFT needs to call `enable_adapters` instead
            if hasattr(module, "enable_adapters"):
-                module.enable_adapters(enabled=False)
                module.enable_adapters(enabled=enabled)
            else:
-                module.disable_adapters = True
                module.disable_adapters = not enabled


def set_weights_and_activate_adapters(model, adapter_names, weights):
@@ -182,3 +194,23 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names


def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the version of PEFT is compatible.

    Args:
        version (`str`):
            The version of PEFT to check against.
    """
    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
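
To make the rank/alpha handling in `get_peft_kwargs` concrete, a small worked sketch (the key names are illustrative, not taken from a real checkpoint):

    rank_dict = {
        "down_blocks.0.attn.to_q.lora_B.weight": 4,
        "down_blocks.0.attn.to_k.lora_B.weight": 4,
        "mid_block.attn.to_v.lora_B.weight": 8,  # the odd one out
    }
    # get_peft_kwargs(rank_dict, None, rank_dict) picks r=4 (the most common rank),
    # keeps lora_alpha=4, records the deviating module in rank_pattern as
    # {"mid_block.attn.to_v": 8}, and derives target_modules by stripping the
    # ".lora..." suffix from every key.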
@@ -28,6 +28,22 @@ class StateDictType(enum.Enum):
    DIFFUSERS = "diffusers"


# We need to define a proper mapping for Unet since it uses different output keys than text encoder
# e.g. to_q_lora -> q_proj / to_q
UNET_TO_DIFFUSERS = {
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
    ".to_k_lora.down": ".to_k.lora_A",
    ".to_k_lora.up": ".to_k.lora_B",
    ".to_v_lora.down": ".to_v.lora_A",
    ".to_v_lora.up": ".to_v.lora_B",
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
}

DIFFUSERS_TO_PEFT = {
    ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
    ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
@@ -50,6 +66,8 @@ DIFFUSERS_OLD_TO_PEFT = {
    ".to_v_lora.down": ".v_proj.lora_A",
    ".to_out_lora.up": ".out_proj.lora_B",
    ".to_out_lora.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}

PEFT_TO_DIFFUSERS = {
@@ -84,6 +102,10 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
    StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}

KEYS_TO_ALWAYS_REPLACE = {
    ".processor.": ".",
}


def convert_state_dict(state_dict, mapping):
    r"""
@@ -103,6 +125,12 @@ def convert_state_dict(state_dict, mapping):
    """
    converted_state_dict = {}
    for k, v in state_dict.items():
        # First, filter out the keys that we always want to replace
        for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
            if pattern in k:
                new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
                k = k.replace(pattern, new_pattern)

        for pattern in mapping.keys():
            if pattern in k:
                new_pattern = mapping[pattern]
@@ -184,3 +212,11 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
    mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
    return convert_state_dict(state_dict, mapping)


def convert_unet_state_dict_to_peft(state_dict):
    r"""
    Converts a state dict from UNet format to diffusers format - i.e. by removing some keys
    """
    mapping = UNET_TO_DIFFUSERS
    return convert_state_dict(state_dict, mapping)
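
A tiny sketch of what `convert_unet_state_dict_to_peft` does to a single key (the key name is illustrative): the `.processor.` segment is stripped first via `KEYS_TO_ALWAYS_REPLACE`, then the old attention-processor suffix `to_q_lora.down` is mapped to PEFT's `lora_A` naming:

    import torch
    from diffusers.utils import convert_unet_state_dict_to_peft

    sd = {"mid_block.attentions.0.processor.to_q_lora.down.weight": torch.zeros(4, 320)}
    print(list(convert_unet_state_dict_to_peft(sd)))
    # -> ['mid_block.attentions.0.to_q.lora_A.weight']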
@@ -673,6 +673,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

@deprecate_after_peft_backend
class SDXInpaintLoraMixinTests(unittest.TestCase):
    def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
@@ -1387,6 +1388,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
        ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth."

@deprecate_after_peft_backend
class UNet2DConditionLoRAModelTests(unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"
@@ -1635,6 +1637,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
        assert max_diff_off_sample < expected_max_diff

@deprecate_after_peft_backend
class UNet3DConditionModelTests(unittest.TestCase):
    model_class = UNet3DConditionModel
    main_input_name = "sample"
@@ -1877,6 +1880,7 @@ class UNet3DConditionModelTests(unittest.TestCase):
@slow
@deprecate_after_peft_backend
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
    def test_dreambooth_old_format(self):
......
@@ -12,21 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import tempfile
import time
import unittest

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDIMScheduler,
    DiffusionPipeline,
    EulerDiscreteScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
@@ -35,9 +41,20 @@ from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
)
-from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, require_torch_gpu, slow
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    load_image,
    nightly,
    require_peft_backend,
    require_torch_gpu,
    slow,
    torch_device,
)

if is_accelerate_available():
    from accelerate.utils import release_memory

if is_peft_available():
    from peft import LoraConfig
@@ -45,6 +62,18 @@ if is_peft_available():
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            models_are_equal = False
    return models_are_equal


def create_unet_lora_layers(unet: nn.Module):
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
@@ -94,6 +123,10 @@ class PeftLoraLoaderMixinTests:
            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
        )

        unet_lora_config = LoraConfig(
            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
        )

        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)

        if self.has_two_text_encoders:
@@ -120,7 +153,7 @@ class PeftLoraLoaderMixinTests:
            "unet_lora_layers": unet_lora_layers,
            "unet_lora_attn_procs": unet_lora_attn_procs,
        }
-        return pipeline_components, lora_components, text_lora_config
        return pipeline_components, lora_components, text_lora_config, unet_lora_config

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
@@ -166,7 +199,7 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple inference and makes sure it works as expected
        """
-        components, _, _ = self.get_dummy_components()
        components, _, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -180,7 +213,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -208,7 +241,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -252,7 +285,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -289,7 +322,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -327,7 +360,7 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -387,7 +420,7 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
-        components, _, text_lora_config = self.get_dummy_components()
        components, _, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(self.torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -431,108 +464,680 @@ class PeftLoraLoaderMixinTests:
            "Loading from saved checkpoints should give same results.",
        )
def test_simple_inference_with_text_unet_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
-    pipeline_class = StableDiffusionPipeline
-    scheduler_cls = DDIMScheduler
-    scheduler_kwargs = {
-        "beta_start": 0.00085,
-        "beta_end": 0.012,
-        "beta_schedule": "scaled_linear",
-        "clip_sample": False,
-        "set_alpha_to_one": False,
-        "steps_offset": 1,
-    }
-    unet_kwargs = {
-        "block_out_channels": (32, 64),
-        "layers_per_block": 2,
-        "sample_size": 32,
-        "in_channels": 4,
-        "out_channels": 4,
-        "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
-        "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
-        "cross_attention_dim": 32,
-    }
-    vae_kwargs = {
-        "block_out_channels": [32, 64],
-        "in_channels": 3,
-        "out_channels": 3,
-        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-        "latent_channels": 4,
-    }
-    @slow
-    @require_torch_gpu
-    def test_integration_logits_with_scale(self):
-        path = "runwayml/stable-diffusion-v1-5"
-        lora_id = "takuma104/lora-test-text-encoder-lora-target"
-        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
-        pipe.load_lora_weights(lora_id)
-        pipe = pipe.to("cuda")

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

        pipe.text_encoder.add_adapter(text_lora_config)
        pipe.unet.add_adapter(unet_lora_config)

        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
unet_state_dict = get_peft_model_state_dict(pipe.unet)
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_layers=unet_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_layers=unet_state_dict,
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
-        self.assertTrue(
-            self.check_if_lora_correctly_set(pipe.text_encoder),
-            "Lora not correctly set in text encoder 2",
-        )
-        prompt = "a red sks dog"

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_text_unet_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-        images = pipe(
-            prompt=prompt,
-            num_inference_steps=15,
-            cross_attention_kwargs={"scale": 0.5},
-            generator=torch.manual_seed(0),
-            output_type="np",

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

        pipe.text_encoder.add_adapter(text_lora_config)
        pipe.unet.add_adapter(unet_lora_config)
        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
-        ).images
        ).images
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
-        expected_slice_scale = np.array([0.307, 0.283, 0.310, 0.310, 0.300, 0.314, 0.336, 0.314, 0.321])
        output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
-        predicted_slice = images[0, -3:, -3:, -1].flatten()
        self.assertTrue(
pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
"The scaling parameter has not been correctly restored!",
)
-        self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))

    def test_simple_inference_with_text_lora_unet_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-    @slow
-    @require_torch_gpu
-    def test_integration_logits_no_scale(self):
-        path = "runwayml/stable-diffusion-v1-5"
-        lora_id = "takuma104/lora-test-text-encoder-lora-target"
-        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
-        pipe.load_lora_weights(lora_id)
-        pipe = pipe.to("cuda")

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

        pipe.text_encoder.add_adapter(text_lora_config)
        pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
if self.has_two_text_encoders:
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_text_unet_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse(
self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
if self.has_two_text_encoders:
self.assertFalse(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
)
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
-        self.assertTrue(
-            self.check_if_lora_correctly_set(pipe.text_encoder),
-            "Lora not correctly set in text encoder",
-        )
-        prompt = "a red sks dog"

        self.assertTrue(
            np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-        images = pipe(prompt=prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images
-        expected_slice_scale = np.array([0.074, 0.064, 0.073, 0.0842, 0.069, 0.0641, 0.0794, 0.076, 0.084])
-        predicted_slice = images[0, -3:, -3:, -1].flatten()

        pipe.text_encoder.add_adapter(text_lora_config)
        pipe.unet.add_adapter(unet_lora_config)

        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

        if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
-        self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))

        pipe.fuse_lora()
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
-    has_two_text_encoders = True
-    pipeline_class = StableDiffusionXLPipeline
-    scheduler_cls = EulerDiscreteScheduler
-    scheduler_kwargs = {
-        "beta_start": 0.00085,
-        "beta_end": 0.012,
-        "beta_schedule": "scaled_linear",
-        "timestep_spacing": "leading",
-        "steps_offset": 1,
-    }
-    unet_kwargs = {
-        "block_out_channels": (32, 64),
-        "layers_per_block": 2,

        pipe.unfuse_lora()

        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
        # unloading should remove the LoRA layers
        self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
        self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders:
            self.assertTrue(
                self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
            )

        # Fuse and unfuse should lead to the same results
        self.assertTrue(
np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
def test_simple_inference_with_text_unet_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_lora_fuse_nan(self):
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
"inf"
)
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(safe_fusing=False)
out = pipe("test", num_inference_steps=2, output_type="np").images
self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-1"])
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-2"])
pipe.set_adapters(["adapter-1", "adapter-2"])
self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
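# Hedged sketch (assumes PEFT's `BaseTunerLayer` interface): `get_active_adapters()` can be
# pictured as reading the active adapter names from the first LoRA layer found in the UNet;
# a newly added adapter becomes the active one by default, hence ["adapter-2"] above.
def _sketch_get_active_adapters(unet):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in unet.modules():
        if isinstance(module, BaseTunerLayer):
            active = module.active_adapters if hasattr(module, "active_adapters") else module.active_adapter
            return active if isinstance(active, list) else [active]
    return []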
def test_get_list_adapters(self):
"""
Tests a simple use case where we attach multiple adapters and check that
`get_list_adapters()` returns the expected adapter names per component
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
adapter_names = pipe.get_list_adapters()
self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
adapter_names = pipe.get_list_adapters()
self.assertDictEqual(
adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
)
pipe.set_adapters(["adapter-1", "adapter-2"])
self.assertDictEqual(
pipe.get_list_adapters(), {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]}
)
pipe.unet.add_adapter(unet_lora_config, "adapter-3")
self.assertDictEqual(
pipe.get_list_adapters(),
{"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
)
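# Hedged sketch: `get_list_adapters()` can be thought of as collecting the registered adapter
# names (the keys of each component's `peft_config`) for every LoRA-capable component,
# regardless of which adapters are currently active.
def _sketch_get_list_adapters(pipe):
    listing = {}
    for component_name in ("text_encoder", "text_encoder_2", "unet"):
        component = getattr(pipe, component_name, None)
        if component is not None and hasattr(component, "peft_config"):
            listing[component_name] = list(component.peft_config.keys())
    return listing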
@unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
"""
Tests a simple inference with lora attached to the text encoder and unet, compiles the
models with `torch.compile`, and makes sure inference works as expected
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
if self.has_two_text_encoders:
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
# Just makes sure inference runs without errors
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
scheduler_kwargs = {
"beta_start": 0.00085,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"clip_sample": False,
"set_alpha_to_one": False,
"steps_offset": 1,
}
unet_kwargs = {
"block_out_channels": (32, 64),
"layers_per_block": 2,
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
"cross_attention_dim": 32,
}
vae_kwargs = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
@slow
@require_torch_gpu
def test_integration_move_lora_cpu(self):
path = "runwayml/stable-diffusion-v1-5"
lora_id = "takuma104/lora-test-text-encoder-lora-target"
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_id, adapter_name="adapter-1")
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to("cuda")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
self.assertTrue(
self.check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in text encoder",
)
# We will offload the first adapter to CPU and check that the offloading
# has been performed correctly
pipe.set_lora_device(["adapter-1"], "cpu")
for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
pipe.set_lora_device(["adapter-1"], 0)
for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
pipe.set_lora_device(["adapter-1", "adapter-2"], "cuda")
for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
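# Hedged sketch: moving a single adapter between devices amounts to relocating only the
# `lora_A` / `lora_B` sub-modules registered under that adapter name, leaving the base
# weights and the other adapters where they are.
def _sketch_set_lora_device(model, adapter_names, device):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer) and hasattr(module, "lora_A"):
            for adapter_name in adapter_names:
                if adapter_name in module.lora_A:
                    module.lora_A[adapter_name].to(device)
                    module.lora_B[adapter_name].to(device)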
@slow
@require_torch_gpu
def test_integration_logits_with_scale(self):
path = "runwayml/stable-diffusion-v1-5"
lora_id = "takuma104/lora-test-text-encoder-lora-target"
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
pipe.load_lora_weights(lora_id)
pipe = pipe.to("cuda")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder 2",
)
prompt = "a red sks dog"
images = pipe(
prompt=prompt,
num_inference_steps=15,
cross_attention_kwargs={"scale": 0.5},
generator=torch.manual_seed(0),
output_type="np",
).images
expected_slice_scale = np.array([0.307, 0.283, 0.310, 0.310, 0.300, 0.314, 0.336, 0.314, 0.321])
predicted_slice = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
@slow
@require_torch_gpu
def test_integration_logits_no_scale(self):
path = "runwayml/stable-diffusion-v1-5"
lora_id = "takuma104/lora-test-text-encoder-lora-target"
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
pipe.load_lora_weights(lora_id)
pipe = pipe.to("cuda")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
prompt = "a red sks dog"
images = pipe(prompt=prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images
expected_slice_scale = np.array([0.074, 0.064, 0.073, 0.0842, 0.069, 0.0641, 0.0794, 0.076, 0.084])
predicted_slice = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
@nightly
@require_torch_gpu
def test_integration_logits_multi_adapter(self):
path = "stabilityai/stable-diffusion-xl-base-1.0"
lora_id = "CiroN2022/toy-face"
pipe = StableDiffusionXLPipeline.from_pretrained(path, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe = pipe.to("cuda")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in Unet",
)
prompt = "toy_face of a hacker with a hoodie"
lora_scale = 0.9
images = pipe(
prompt=prompt,
num_inference_steps=30,
generator=torch.manual_seed(0),
cross_attention_kwargs={"scale": lora_scale},
output_type="np",
).images
expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539])
predicted_slice = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.set_adapters("pixel")
prompt = "pixel art, a hacker with a hoodie, simple, flat colors"
images = pipe(
prompt,
num_inference_steps=30,
guidance_scale=7.5,
cross_attention_kwargs={"scale": lora_scale},
generator=torch.manual_seed(0),
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array(
[0.61973065, 0.62018543, 0.62181497, 0.61933696, 0.6208608, 0.620576, 0.6200281, 0.62258327, 0.6259889]
)
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
# multi-adapter inference
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
images = pipe(
prompt,
num_inference_steps=30,
guidance_scale=7.5,
cross_attention_kwargs={"scale": 1.0},
generator=torch.manual_seed(0),
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.5977, 0.5985, 0.6039, 0.5976, 0.6025, 0.6036, 0.5946, 0.5979, 0.5998])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
# Lora disabled
pipe.disable_lora()
images = pipe(
prompt,
num_inference_steps=30,
guidance_scale=7.5,
cross_attention_kwargs={"scale": lora_scale},
generator=torch.manual_seed(0),
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.54625, 0.5473, 0.5495, 0.5465, 0.5476, 0.5461, 0.5452, 0.5485, 0.5493])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
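# Hedged numerical illustration (toy tensors, not the pipeline's real weights) of how
# `set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])` combines contributions:
# each LoRA delta is scaled by its per-adapter weight before being added to the base output,
# which is why the mixed slice above differs from both single-adapter slices.
def _sketch_weighted_adapter_combination():
    torch.manual_seed(0)
    base_out = torch.randn(4)
    pixel_delta, toy_delta = torch.randn(4), torch.randn(4)
    combined = base_out + 0.5 * pixel_delta + 1.0 * toy_delta
    return combined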
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline
scheduler_cls = EulerDiscreteScheduler
scheduler_kwargs = {
"beta_start": 0.00085,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"timestep_spacing": "leading",
"steps_offset": 1,
}
unet_kwargs = {
"block_out_channels": (32, 64),
"layers_per_block": 2,
"sample_size": 32, "sample_size": 32,
"in_channels": 4, "in_channels": 4,
"out_channels": 4, "out_channels": 4,
...@@ -555,3 +1160,606 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -555,3 +1160,606 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4, "latent_channels": 4,
"sample_size": 128, "sample_size": 128,
} }
@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def tearDown(self):
import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe(
"A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
def test_dreambooth_text_encoder_new_format(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/lora-trained"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
def test_a1111(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to(
torch_device
)
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_lycoris(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
).to(torch_device)
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
lora_filename = "edgLycorisMugler-light.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_a1111_with_sequential_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_kohya_sd_v15_with_higher_dimensions(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
torch_device
)
lora_model_id = "hf-internal-testing/urushisato-lora"
lora_filename = "urushisato_v15.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7165, 0.6616, 0.5833, 0.7504, 0.6718, 0.587, 0.6871, 0.6361, 0.5694])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_vanilla_finetuning(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
def test_unload_kohya_lora(self):
generator = torch.manual_seed(0)
prompt = "masterpiece, best quality, mountain"
num_inference_steps = 2
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
torch_device
)
initial_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
initial_images = initial_images[0, -3:, -3:, -1].flatten()
lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
lora_filename = "Colored_Icons_by_vizsumit.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
lora_images = lora_images[0, -3:, -3:, -1].flatten()
pipe.unload_lora_weights()
generator = torch.manual_seed(0)
unloaded_lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
release_memory(pipe)
def test_load_unload_load_kohya_lora(self):
# This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded
# without introducing any side-effects. Even though the test uses a Kohya-style
# LoRA, the underlying adapter handling mechanism is format-agnostic.
generator = torch.manual_seed(0)
prompt = "masterpiece, best quality, mountain"
num_inference_steps = 2
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
torch_device
)
initial_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
initial_images = initial_images[0, -3:, -3:, -1].flatten()
lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
lora_filename = "Colored_Icons_by_vizsumit.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
lora_images = lora_images[0, -3:, -3:, -1].flatten()
pipe.unload_lora_weights()
generator = torch.manual_seed(0)
unloaded_lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
# make sure we can load a LoRA again after unloading and that the reload doesn't have
# any undesired effects.
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images_again = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
release_memory(pipe)
@slow
@require_torch_gpu
class LoraSDXLIntegrationTests(unittest.TestCase):
def tearDown(self):
import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_sdxl_0_9_lora_one(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
lora_filename = "daiton-xl-lora-test.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_sdxl_0_9_lora_two(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
lora_filename = "saijo.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_sdxl_0_9_lora_three(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
self.assertTrue(np.allclose(images, expected, atol=5e-3))
release_memory(pipe)
def test_sdxl_1_0_lora(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
def test_sdxl_1_0_lora_fusion(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.fuse_lora()
# We need to unload the LoRA weights explicitly: in the previous API, `fuse_lora` silently
# deleted them after fusion, but now they stay in memory - otherwise this will OOM on CPU
pipe.unload_lora_weights()
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
# This way we also test equivalence between LoRA fusion and the non-fusion behaviour.
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
def test_sdxl_1_0_lora_unfusion(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.fuse_lora()
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images_with_fusion = images[0, -3:, -3:, -1].flatten()
pipe.unfuse_lora()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images_without_fusion = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
release_memory(pipe)
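# Hedged numerical illustration of why fuse -> unfuse is expected to round-trip: the same
# scaled delta (B @ A) that was added to the base weight is subtracted again, so the two
# image slices above should agree up to floating point error.
def _sketch_fuse_unfuse_roundtrip():
    torch.manual_seed(0)
    W = torch.randn(8, 8)
    A = torch.randn(4, 8)
    B = torch.randn(8, 4)
    scale = 1.0
    fused = W + scale * (B @ A)
    unfused = fused - scale * (B @ A)
    return torch.allclose(W, unfused, atol=1e-6)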
def test_sdxl_1_0_lora_unfusion_effectivity(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
original_image_slice = images[0, -3:, -3:, -1].flatten()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
_ = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
pipe.unfuse_lora()
# We need to unload the LoRA weights explicitly - in the old API, `unfuse_lora` also unloaded the adapter weights
pipe.unload_lora_weights()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images_without_fusion_slice = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3))
release_memory(pipe)
def test_sdxl_1_0_lora_fusion_efficiency(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
start_time = time.time()
for _ in range(3):
pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
end_time = time.time()
elapsed_time_non_fusion = end_time - start_time
del pipe
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
pipe.fuse_lora()
# We need to unload the LoRA weights explicitly: in the previous API, `fuse_lora` silently
# deleted them after fusion, but now they stay in memory - otherwise this will OOM on CPU
pipe.unload_lora_weights()
pipe.enable_model_cpu_offload()
start_time = time.time()
generator = torch.Generator().manual_seed(0)
for _ in range(3):
pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
end_time = time.time()
elapsed_time_fusion = end_time - start_time
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
release_memory(pipe)
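# Hedged note on the timing comparison above: fusing folds the LoRA update into the base
# weight once (W <- W + scale * B @ A), so each forward pass runs a single matmul per adapted
# layer instead of the base matmul plus the LoRA branch, which is why fused inference is
# expected to be faster.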
def test_sdxl_1_0_last_ben(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
lora_model_id = "TheLastBen/Papercut_SDXL"
lora_filename = "papercut.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_sdxl_1_0_fuse_unfuse_all(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
unet_sd = copy.deepcopy(pipe.unet.state_dict())
pipe.load_lora_weights(
"davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16
)
fused_te_state_dict = pipe.text_encoder.state_dict()
fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
unet_state_dict = pipe.unet.state_dict()
for key, value in text_encoder_1_sd.items():
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
for key, value in text_encoder_2_sd.items():
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
for key, value in unet_sd.items():
self.assertTrue(torch.allclose(unet_state_dict[key], value))
pipe.fuse_lora()
pipe.unload_lora_weights()
assert not state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
assert not state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
assert not state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
release_memory(pipe)
del unet_sd, text_encoder_1_sd, text_encoder_2_sd
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
def test_canny_lora(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "corgi"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
assert images[0].shape == (768, 512, 3)
original_image = images[0, -3:, -3:, -1].flatten()
expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333])
assert np.allclose(original_image, expected_image, atol=1e-04)
release_memory(pipe)
@nightly
def test_sequential_fuse_unfuse(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
# 1. round
pipe.load_lora_weights("Pclanglais/TintinIA", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice = images[0, -3:, -3:, -1].flatten()
pipe.unfuse_lora()
# 2. round
pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style", torch_dtype=torch.float16)
pipe.fuse_lora()
pipe.unfuse_lora()
# 3. round
pipe.load_lora_weights("ostris/crayon_style_lora_sdxl", torch_dtype=torch.float16)
pipe.fuse_lora()
pipe.unfuse_lora()
# 4. back to 1st round
pipe.load_lora_weights("Pclanglais/TintinIA", torch_dtype=torch.float16)
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images_2 = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))
release_memory(pipe)