Unverified Commit 4f74a5e1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[PEFT warnings] Only sure deprecation warnings in the future (#5240)

* [PEFT warnings] Only sure deprecation warnings in the future

* make style
parent bbe8d3ae
......@@ -27,6 +27,7 @@ from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn
from . import __version__
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
DIFFUSERS_CACHE,
......@@ -1708,7 +1709,8 @@ class LoraLoaderMixin:
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
......@@ -1736,7 +1738,8 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
......@@ -2123,7 +2126,8 @@ class LoraLoaderMixin:
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
......@@ -2173,7 +2177,8 @@ class LoraLoaderMixin:
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
if version.parse(__version__) > version.parse("0.23"):
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment