"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7747b588e25cb5eef4e86f13813c68e1f95849c8"
Unverified commit dc3e0ca5, authored by Patrick von Platen, committed by GitHub
Browse files

[Textual inversion] Relax loading textual inversion (#4903)

* [Textual inversion] Relax loading textual inversion

* up
parent 6c314ad0
...@@ -663,6 +663,8 @@ class TextualInversionLoaderMixin: ...@@ -663,6 +663,8 @@ class TextualInversionLoaderMixin:
self, self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None, token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
text_encoder: Optional[PreTrainedModel] = None,
**kwargs, **kwargs,
): ):
r""" r"""
...@@ -684,6 +686,11 @@ class TextualInversionLoaderMixin: ...@@ -684,6 +686,11 @@ class TextualInversionLoaderMixin:
token (`str` or `List[str]`, *optional*): token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length. list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
If not specified, function will take self.text_encoder.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
weight_name (`str`, *optional*): weight_name (`str`, *optional*):
Name of a custom weight file. This should be used when: Name of a custom weight file. This should be used when:
...@@ -757,15 +764,18 @@ class TextualInversionLoaderMixin: ...@@ -757,15 +764,18 @@ class TextualInversionLoaderMixin:
``` ```
""" """
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if tokenizer is None:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`" f" `{self.load_textual_inversion.__name__}`"
) )
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): if text_encoder is None:
raise ValueError( raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`" f" `{self.load_textual_inversion.__name__}`"
) )
...@@ -830,7 +840,7 @@ class TextualInversionLoaderMixin: ...@@ -830,7 +840,7 @@ class TextualInversionLoaderMixin:
token_ids_and_embeddings = [] token_ids_and_embeddings = []
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
if not isinstance(pretrained_model_name_or_path, dict): if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
# 1. Load textual inversion file # 1. Load textual inversion file
model_file = None model_file = None
# Let's first try to load .safetensors weights # Let's first try to load .safetensors weights
...@@ -897,10 +907,10 @@ class TextualInversionLoaderMixin: ...@@ -897,10 +907,10 @@ class TextualInversionLoaderMixin:
else: else:
token = loaded_token token = loaded_token
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
# 3. Make sure we don't mess up the tokenizer or text encoder # 3. Make sure we don't mess up the tokenizer or text encoder
vocab = self.tokenizer.get_vocab() vocab = tokenizer.get_vocab()
if token in vocab: if token in vocab:
raise ValueError( raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
...@@ -908,7 +918,7 @@ class TextualInversionLoaderMixin: ...@@ -908,7 +918,7 @@ class TextualInversionLoaderMixin:
elif f"{token}_1" in vocab: elif f"{token}_1" in vocab:
multi_vector_tokens = [token] multi_vector_tokens = [token]
i = 1 i = 1
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: while f"{token}_{i}" in tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}") multi_vector_tokens.append(f"{token}_{i}")
i += 1 i += 1
...@@ -926,16 +936,16 @@ class TextualInversionLoaderMixin: ...@@ -926,16 +936,16 @@ class TextualInversionLoaderMixin:
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
# add tokens and get ids # add tokens and get ids
self.tokenizer.add_tokens(tokens) tokenizer.add_tokens(tokens)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids_and_embeddings += zip(token_ids, embeddings) token_ids_and_embeddings += zip(token_ids, embeddings)
logger.info(f"Loaded textual inversion embedding for {token}.") logger.info(f"Loaded textual inversion embedding for {token}.")
# resize token embeddings and set all new embeddings # resize token embeddings and set all new embeddings
self.text_encoder.resize_token_embeddings(len(self.tokenizer)) text_encoder.resize_token_embeddings(len(tokenizer))
for token_id, embedding in token_ids_and_embeddings: for token_id, embedding in token_ids_and_embeddings:
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding text_encoder.get_input_embeddings().weight.data[token_id] = embedding
# offload back # offload back
if is_model_cpu_offload: if is_model_cpu_offload:
......
...@@ -84,7 +84,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): ...@@ -84,7 +84,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg return noise_cfg
class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion XL. Pipeline for text-to-image generation using Stable Diffusion XL.
......
...@@ -84,7 +84,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): ...@@ -84,7 +84,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg return noise_cfg
class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): class StableDiffusionXLImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion XL. Pipeline for text-to-image generation using Stable Diffusion XL.
......
...@@ -230,7 +230,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool ...@@ -230,7 +230,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image return mask, masked_image
class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin): class StableDiffusionXLInpaintPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion XL. Pipeline for text-to-image generation using Stable Diffusion XL.
......
...@@ -62,7 +62,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): ...@@ -62,7 +62,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg return noise_cfg
class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): class StableDiffusionXLInstructPix2PixPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r""" r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.
......
...@@ -123,7 +123,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): ...@@ -123,7 +123,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg return noise_cfg
class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): class StableDiffusionXLAdapterPipeline(
DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
https://arxiv.org/abs/2302.08453 https://arxiv.org/abs/2302.08453
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment