Unverified commit c81a88b2, authored by Sayak Paul and committed by GitHub
Browse files

[Core] LoRA improvements pt. 3 (#4842)



* throw warning when more than one lora is attempted to be fused.

* introduce support of lora scale during fusion.

* change test name

* changes

* change to _lora_scale

* lora_scale to call whenever applicable.

* debugging

* lora_scale additional.

* cross_attention_kwargs

* lora_scale -> scale.

* lora_scale fix

* lora_scale in patched projection.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* styling.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* remove unneeded prints.

* remove unneeded prints.

* assign cross_attention_kwargs.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up.

* refactor scale retrieval logic a bit.

* fix nonetype

* fix: tests

* add more tests

* more fixes.

* figure out a way to pass lora_scale.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* unify the retrieval logic of lora_scale.

* move adjust_lora_scale_text_encoder to lora.py.

* introduce dynamic adjustment lora scale support to sd

* fix up copies

* Empty-Commit

* add: test to check fusion equivalence on different scales.

* handle lora fusion warning.

* make lora smaller

* make lora smaller

* make lora smaller

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c1677ee
...@@ -26,6 +26,7 @@ from ...configuration_utils import FrozenDict ...@@ -26,6 +26,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -502,6 +503,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -502,6 +503,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessor ...@@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import GatedSelfAttentionDense from ...models.attention import GatedSelfAttentionDense
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
deprecate, deprecate,
...@@ -298,6 +299,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline): ...@@ -298,6 +299,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -30,6 +30,7 @@ from ...image_processor import VaeImageProcessor ...@@ -30,6 +30,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import GatedSelfAttentionDense from ...models.attention import GatedSelfAttentionDense
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
...@@ -331,6 +332,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline): ...@@ -331,6 +332,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict ...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -323,6 +324,9 @@ class StableDiffusionImg2ImgPipeline( ...@@ -323,6 +324,9 @@ class StableDiffusionImg2ImgPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict ...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -393,6 +394,9 @@ class StableDiffusionInpaintPipeline( ...@@ -393,6 +394,9 @@ class StableDiffusionInpaintPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict ...@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -322,6 +323,9 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -322,6 +323,9 @@ class StableDiffusionInpaintPipelineLegacy(
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -22,6 +22,7 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras ...@@ -22,6 +22,7 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LMSDiscreteScheduler from ...schedulers import LMSDiscreteScheduler
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -229,6 +230,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade ...@@ -229,6 +230,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -24,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -24,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessorLDM3D from ...image_processor import VaeImageProcessorLDM3D
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
BaseOutput, BaseOutput,
...@@ -293,6 +294,9 @@ class StableDiffusionLDM3DPipeline( ...@@ -293,6 +294,9 @@ class StableDiffusionLDM3DPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -21,6 +21,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -21,6 +21,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import PNDMScheduler from ...schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin from ...schedulers.scheduling_utils import SchedulerMixin
from ...utils import deprecate, logging, randn_tensor from ...utils import deprecate, logging, randn_tensor
...@@ -233,6 +234,9 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -233,6 +234,9 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler from ...schedulers import DDIMScheduler
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -210,6 +211,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -210,6 +211,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
deprecate, deprecate,
...@@ -277,6 +278,9 @@ class StableDiffusionParadigmsPipeline( ...@@ -277,6 +278,9 @@ class StableDiffusionParadigmsPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -32,6 +32,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor ...@@ -32,6 +32,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from ...utils import ( from ...utils import (
...@@ -462,6 +463,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -462,6 +463,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -22,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer ...@@ -22,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -233,6 +234,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -233,6 +234,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -30,6 +30,7 @@ from ...models.attention_processor import ( ...@@ -30,6 +30,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -255,6 +256,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -255,6 +256,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -23,6 +23,7 @@ from ...image_processor import VaeImageProcessor ...@@ -23,6 +23,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding from ...models.embeddings import get_timestep_embedding
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
deprecate, deprecate,
...@@ -364,6 +365,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -364,6 +365,9 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor ...@@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding from ...models.embeddings import get_timestep_embedding
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -313,6 +314,9 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin ...@@ -313,6 +314,9 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -20,7 +20,11 @@ import torch ...@@ -20,7 +20,11 @@ import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import (
FromSingleFileMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import ( from ...models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
...@@ -28,6 +32,7 @@ from ...models.attention_processor import ( ...@@ -28,6 +32,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
...@@ -284,6 +289,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -284,6 +289,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -29,6 +29,7 @@ from ...models.attention_processor import ( ...@@ -29,6 +29,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
...@@ -294,6 +295,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -294,6 +295,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -30,6 +30,7 @@ from ...models.attention_processor import ( ...@@ -30,6 +30,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
deprecate, deprecate,
...@@ -444,6 +445,10 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS ...@@ -444,6 +445,10 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
...@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
...@@ -318,6 +319,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): ...@@ -318,6 +319,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
if lora_scale is not None and isinstance(self, LoraLoaderMixin): if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif prompt is not None and isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment