"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "21bbc633c4d7b9bb7f74caf4b248c6a4079a85c6"
Unverified Commit f0216b77 authored by Hezi Zisman, committed by GitHub

allow explicit tokenizer & text_encoder in unload_textual_inversion (#6977)



* allow passing tokenizer & text_encoder to unload_textual_inversion


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Fabio Rigano <fabio2rigano@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent d5f444de
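With this change, `unload_textual_inversion` no longer has to operate only on the pipeline's default `tokenizer` / `text_encoder`; a specific pair can be passed explicitly, which is what SDXL's second text encoder needs. A minimal sketch of the new call pattern (the model ID and `<s0>` / `<s1>` token names are illustrative, and the tokens are assumed to have been loaded beforehand via `load_textual_inversion`):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Illustrative pipeline; any pipeline that uses TextualInversionLoaderMixin works.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Default behavior (unchanged): falls back to pipeline.tokenizer / pipeline.text_encoder.
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"])

# New: explicitly choose which tokenizer / text encoder pair to unload from,
# e.g. SDXL's second text encoder (assumes the tokens were loaded into it).
pipeline.unload_textual_inversion(
    tokens=["<s0>", "<s1>"],
    tokenizer=pipeline.tokenizer_2,
    text_encoder=pipeline.text_encoder_2,
)
```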
@@ -457,6 +457,8 @@ class TextualInversionLoaderMixin:
     def unload_textual_inversion(
         self,
         tokens: Optional[Union[str, List[str]]] = None,
+        tokenizer: Optional["PreTrainedTokenizer"] = None,
+        text_encoder: Optional["PreTrainedModel"] = None,
     ):
         r"""
         Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
@@ -481,11 +483,28 @@ class TextualInversionLoaderMixin:
         # Remove just one token
         pipeline.unload_textual_inversion("<moe-bius>")
+
+        # Example 3: unload from SDXL
+        pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
+
+        # load embeddings to the text encoders
+        state_dict = load_file(embedding_path)
+
+        # load embeddings of text_encoder 1 (CLIP ViT-L/14)
+        pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+        # load embeddings of text_encoder 2 (CLIP ViT-G/14)
+        pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+
+        # Unload explicitly from both text encoders and tokenizers
+        pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+        pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
         ```
         """
-        tokenizer = getattr(self, "tokenizer", None)
-        text_encoder = getattr(self, "text_encoder", None)
+        tokenizer = tokenizer or getattr(self, "tokenizer", None)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)

         # Get textual inversion tokens and ids
         token_ids = []