Unverified Commit a937e1b5 authored by Pi Esposito, committed by GitHub

add load textual inversion embeddings to stable diffusion (#2009)



* add load textual inversion embeddings draft

* fix quality

* fix typo

* make fix copies

* move to textual inversion mixin

* make it accept from sd-concept library

* accept list of paths to embeddings

* fix styling of stable diffusion pipeline

* add dummy TextualInversionMixin

* add docstring to textualinversionmixin

* add case for parsing embedding from auto1111 UI format
Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>

* fix style after rebase

* move textual inversion mixin to loaders

* move mixin inheritance from StableDiffusionPipeline to DiffusionPipeline

* update dummy class name

* addressed all comments

* fix old dangling import

* fix style

* proposal

* remove bogus

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>

* finish

* make style

* up

* fix code quality

* fix code quality - again

* fix code quality - 3

* fix alt diffusion code quality

* fix model editing pipeline

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Finish

---------
Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 1d033a95
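
A minimal usage sketch of the API this commit adds (the base-model and concept repo ids are real Hub examples; the local Automatic1111 directory and file name are hypothetical):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# diffusers format: the repo ships a learned_embeds.bin that registers "<cat-toy>"
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# Automatic1111 format: weight_name is required because the file is not named learned_embeds.bin
# pipe.load_textual_inversion("./embeddings", weight_name="my_embedding.pt")

image = pipe("a <cat-toy> themed backpack").images[0]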
...@@ -109,6 +109,7 @@ try:
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .loaders import TextualInversionLoaderMixin
    from .pipelines import (
        AltDiffusionImg2ImgPipeline,
        AltDiffusionPipeline,
...
...@@ -13,18 +13,28 @@
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Union

import torch

from .models.attention_processor import LoRAAttnProcessor
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
    _get_model_file,
    deprecate,
    is_safetensors_available,
    is_transformers_available,
    logging,
)


if is_safetensors_available():
    import safetensors

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = logging.get_logger(__name__)
...@@ -32,6 +42,9 @@ logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"


class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
...@@ -123,13 +136,6 @@ class UNet2DConditionLoadersMixin:
        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>
        """
...@@ -292,5 +298,272 @@ class UNet2DConditionLoadersMixin:
        # Save the model
        save_function(state_dict, os.path.join(save_directory, weight_name))

        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
class TextualInversionLoaderMixin:
    r"""
    Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        if not isinstance(prompt, List):
            prompts = [prompt]
        else:
            prompts = prompt

        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

        if not isinstance(prompt, List):
            return prompts[0]

        return prompts
    def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        tokens = tokenizer.tokenize(prompt)
        for token in tokens:
            if token in tokenizer.added_tokens_encoder:
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    # separate the vector tokens with spaces so the expanded prompt tokenizes cleanly
                    replacement += f" {token}_{i}"
                    i += 1

                prompt = prompt.replace(token, replacement)

        return prompt
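
    # Illustrative example (token name hypothetical): if "<cat-toy>" was loaded as a
    # three-vector embedding, the tokenizer also contains "<cat-toy>_1" and "<cat-toy>_2",
    # so _maybe_convert_prompt("a <cat-toy> backpack", tokenizer) returns
    # "a <cat-toy> <cat-toy>_1 <cat-toy>_2 backpack".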
    def load_textual_inversion(
        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
    ):
        r"""
        Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
        `Automatic1111` formats are supported.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like
                      `"sd-concepts-library/low-poly-hd-logos-icons"`.
                    - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used in two cases:

                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific
                      weight name, such as `text_inv.bin`.
                    - The saved textual inversion file is in the "Automatic1111" format.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
                safety of the source; please refer to the mirror site for more information.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>
        """
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}
# 1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e
model_file = None
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
        # 2. Load token and embedding correctly from file
        if isinstance(state_dict, torch.Tensor):
            if token is None:
                raise ValueError(
                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                )
            embedding = state_dict
            # a raw tensor carries no token name of its own, so fall back to the passed token
            loaded_token = token
        elif len(state_dict) == 1:
            # diffusers format
            loaded_token, embedding = next(iter(state_dict.items()))
        elif "string_to_param" in state_dict:
            # A1111 format
            loaded_token = state_dict["name"]
            embedding = state_dict["string_to_param"]["*"]

        if token is not None and loaded_token != token:
            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
        else:
            token = loaded_token

        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

        # 3. Make sure we don't mess up the tokenizer or text encoder
        vocab = self.tokenizer.get_vocab()
        if token in vocab:
            raise ValueError(
                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
            )
        elif f"{token}_1" in vocab:
            multi_vector_tokens = [token]
            i = 1
            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
                multi_vector_tokens.append(f"{token}_{i}")
                i += 1

            raise ValueError(
                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
            )

        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

        if is_multi_vector:
            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
            embeddings = [e for e in embedding]  # noqa: C416
        else:
            tokens = [token]
            embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]

        # add tokens and get ids
        self.tokenizer.add_tokens(tokens)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # resize token embeddings and set new embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        for token_id, embedding in zip(token_ids, embeddings):
            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

        logger.info(f"Loaded textual inversion embedding for {token}.")
...@@ -16,27 +16,22 @@
import inspect
import os
import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError
from torch import Tensor, device

from .. import __version__
from ..utils import (
    CONFIG_NAME,
    DEPRECATED_REVISION_ARGS,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HF_HUB_OFFLINE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_model_file,
    is_accelerate_available,
    is_safetensors_available,
    is_torch_version,
...@@ -144,15 +139,6 @@ def _load_state_dict_into_model(model_to_load, state_dict):
    return error_msgs


def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    if variant is not None:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)

    return weights_name
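
# Example of the helper above: _add_variant("diffusion_pytorch_model.bin", "fp16")
# returns "diffusion_pytorch_model.fp16.bin"; with variant=None the name is returned unchanged.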
class ModelMixin(torch.nn.Module):
    r"""
    Base class for all models.
...@@ -789,121 +775,3 @@ class ModelMixin(torch.nn.Module):
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _get_model_file(
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
    commit_hash=None,
):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # Load from a PyTorch checkpoint
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
            return model_file
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        ):
            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            return model_file
        else:
            raise EnvironmentError(
                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
            )
    else:
        # 1. First check if deprecated way of loading from branches is used
        if (
            revision in DEPRECATED_REVISION_ARGS
            and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
            and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
        ):
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=_add_variant(weights_name, revision),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision or commit_hash,
                )
                warnings.warn(
                    f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
                    FutureWarning,
                )
                return model_file
            except:  # noqa: E722
                warnings.warn(
                    f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
                    FutureWarning,
                )
        try:
            # 2. Load model file as usual
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision or commit_hash,
            )
            return model_file

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {weights_name} or"
                " \nCheckout your internet connection or see how to run the library in"
                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {weights_name}"
            )
...@@ -22,6 +22,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
...@@ -49,7 +50,7 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Alt Diffusion.
...@@ -312,6 +313,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -372,6 +377,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -25,6 +25,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
...@@ -88,7 +89,7 @@ def preprocess(image):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Alt Diffusion.
...@@ -322,6 +323,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -382,6 +387,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -24,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
...@@ -118,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
    return noise


class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.
...@@ -338,6 +339,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -398,6 +403,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -20,6 +20,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
"""


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
...@@ -315,6 +316,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -375,6 +380,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -21,6 +21,7 @@ import torch
from torch.nn import functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers
...@@ -159,7 +160,7 @@ class AttendExciteAttnProcessor:
        return hidden_states


class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite.
...@@ -335,6 +336,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -395,6 +400,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -23,6 +23,7 @@ import torch
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import ModelMixin
...@@ -146,7 +147,7 @@ class MultiControlNetModel(ModelMixin):
        return down_block_res_samples, mid_block_res_sample


class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
...@@ -354,6 +355,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -414,6 +419,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -23,6 +23,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
...@@ -54,7 +55,7 @@ def preprocess(image):
    return image


class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.
...@@ -200,6 +201,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -260,6 +265,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -23,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -91,7 +92,7 @@ def preprocess(image):
    return image


class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.
...@@ -329,6 +330,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -389,6 +394,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -22,6 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
...@@ -137,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask):
    return mask, masked_image


class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
...@@ -381,6 +382,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -441,6 +446,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -22,6 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -81,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8):
    return mask


class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
...@@ -317,6 +318,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -377,6 +382,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -20,6 +20,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -60,7 +61,7 @@ def preprocess(image):
    return image


class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
...@@ -511,6 +512,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -571,6 +576,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

from ...loaders import TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
...@@ -41,7 +42,7 @@ class ModelWrapper:
        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample


class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
...@@ -238,6 +239,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -298,6 +303,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
...@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin
...@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
"""


class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    r"""
    Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".
...@@ -266,6 +267,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
...@@ -326,6 +331,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
...
@@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class StableDiffusionPanoramaPipeline(DiffusionPipeline):
+class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
     Generation".
@@ -230,6 +231,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -290,6 +295,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
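With `TextualInversionLoaderMixin` now in the panorama pipeline's bases, the `isinstance` guard above passes and loaded embeddings take effect at inference time. A hedged usage sketch: the concept repo `sd-concepts-library/cat-toy`, its `<cat-toy>` token, and the base checkpoint are illustrative, and the entry point is assumed to be the mixin's `load_textual_inversion` method added by this PR:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPanoramaPipeline

model_id = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# Assumed mixin entry point; the concept repo id is a placeholder.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# "<cat-toy>" is expanded by maybe_convert_prompt before tokenization.
image = pipe(prompt="a panorama of <cat-toy> on a shelf").images[0]
```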
@@ -28,6 +28,7 @@ from transformers import (
     CLIPTokenizer,
 )
 
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
 from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
@@ -50,7 +51,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
+class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
@@ -470,6 +471,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -530,6 +535,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
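The `isinstance` guard only fires when the concrete pipeline class actually inherits the mixin, which is why it must sit on the pipeline itself rather than on an output dataclass. A self-contained illustration with stand-in classes (none of these names are from the diff):

```python
class TextualInversionLoaderMixin:  # stand-in for the real mixin
    def maybe_convert_prompt(self, prompt, tokenizer):
        # the real method expands multi-vector placeholder tokens
        return prompt


class WithLoader(TextualInversionLoaderMixin):
    pass


class WithoutLoader:
    pass


for pipeline in (WithLoader(), WithoutLoader()):
    # only WithLoader would take the conversion branch in _encode_prompt
    print(type(pipeline).__name__, isinstance(pipeline, TextualInversionLoaderMixin))
# WithLoader True
# WithoutLoader False
```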
@@ -19,6 +19,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -87,7 +88,7 @@ class CrossAttnStoreProcessor:
 
 # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
-class StableDiffusionSAGPipeline(DiffusionPipeline):
+class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -247,6 +248,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -307,6 +312,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
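The embeddings these pipelines consume arrive in two common on-disk layouts: the diffusers / sd-concepts-library style, a plain `{token: tensor}` dict, and the Automatic1111 UI format, which nests the tensor under a `string_to_param` key. A hedged sketch of telling them apart (the file name is a placeholder; error handling omitted):

```python
import torch

state_dict = torch.load("learned_embeds.bin", map_location="cpu")

if "string_to_param" in state_dict:
    # Automatic1111 UI format: {"string_to_param": {"*": tensor}, ...}
    embedding = next(iter(state_dict["string_to_param"].values()))
else:
    # diffusers format: {"<token>": tensor}
    embedding = next(iter(state_dict.values()))

print(embedding.shape)  # e.g. (num_vectors, hidden_dim) for a multi-vector concept
```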
@@ -20,6 +20,7 @@ import PIL
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
 
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
@@ -50,7 +51,7 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionUpscalePipeline(DiffusionPipeline):
+class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image super-resolution using Stable Diffusion 2.
@@ -194,6 +195,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -254,6 +259,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt
 
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
...
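A side effect common to all five files: expansion happens before `self.tokenizer(...)` is called, so a multi-vector concept consumes several of CLIP's 77 text positions. A quick way to observe this with a stock tokenizer (checkpoint id and token names are illustrative):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
tokenizer.add_tokens(["<cat-toy>", "<cat-toy>_1", "<cat-toy>_2"])

# The expanded prompt spends three positions on the concept alone.
ids = tokenizer("a photo of <cat-toy> <cat-toy>_1 <cat-toy>_2")["input_ids"]
print(len(ids))  # 8: bos + 3 words + 3 concept tokens + eos
```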