Unverified Commit 3d8b3d7c authored by pdoane, committed by GitHub

Batched load of textual inversions (#3277)



* Batched load of textual inversions

- Only call resize_token_embeddings once per batch as it is the most expensive operation
- Allow pretrained_model_name_or_path and token to be an optional list
- Remove Dict from the type annotation of pretrained_model_name_or_path, as it was not supported by this function
- Add comment that single files (e.g. .pt/.safetensors) are supported
- Add comment for token parameter
- Convert token override log message from warning to info
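
For context, a minimal usage sketch of the batched API after this change (the pipeline checkpoint and placeholder tokens are illustrative and not part of this PR; the concept repo and file path are the ones named in the docstring below):

```python
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any pipeline using TextualInversionLoaderMixin works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Several embeddings are loaded in one call; the expensive
# text_encoder.resize_token_embeddings() now runs once for the whole batch.
pipe.load_textual_inversion(
    ["sd-concepts-library/low-poly-hd-logos-icons", "./my_text_inversions.pt"],
    token=["<logo>", "<my-style>"],  # optional; must match the number of sources if given
)

image = pipe("a <logo> drawn in <my-style>").images[0]
```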

* Update src/diffusers/loaders.py

Check for duplicate tokens
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update condition for None tokens

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0ffac979
src/diffusers/loaders.py

@@ -436,7 +436,10 @@ class TextualInversionLoaderMixin:
         return prompt

     def load_textual_inversion(
-        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+        self,
+        pretrained_model_name_or_path: Union[str, List[str]],
+        token: Optional[Union[str, List[str]]] = None,
+        **kwargs,
     ):
         r"""
         Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,7 +452,7 @@ class TextualInversionLoaderMixin:
         </Tip>

         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                 Can be either:

                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -457,6 +460,12 @@ class TextualInversionLoaderMixin:
                       `"sd-concepts-library/low-poly-hd-logos-icons"`.
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
+                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+
+                Or a list of those elements.
+            token (`str` or `List[str]`, *optional*):
+                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+                list, then `token` must also be a list of equal length.
             weight_name (`str`, *optional*):
                 Name of a custom weight file. This should be used in two cases:
@@ -576,6 +585,31 @@ class TextualInversionLoaderMixin:
             "framework": "pytorch",
         }

+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
+        else:
+            pretrained_model_name_or_paths = pretrained_model_name_or_path
+
+        if isinstance(token, str):
+            tokens = [token]
+        elif token is None:
+            tokens = [None] * len(pretrained_model_name_or_paths)
+        else:
+            tokens = token
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+        token_ids_and_embeddings = []
+
+        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
             # 1. Load textual inversion file
             model_file = None

             # Let's first try to load .safetensors weights
@@ -635,7 +669,7 @@ class TextualInversionLoaderMixin:
                 embedding = state_dict["string_to_param"]["*"]

             if token is not None and loaded_token != token:
-                logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
             else:
                 token = loaded_token
@@ -670,14 +704,15 @@ class TextualInversionLoaderMixin:
             # add tokens and get ids
             self.tokenizer.add_tokens(tokens)
             token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+            token_ids_and_embeddings += zip(token_ids, embeddings)

-        # resize token embeddings and set new embeddings
+            logger.info(f"Loaded textual inversion embedding for {token}.")
+
+        # resize token embeddings and set all new embeddings
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))
-        for token_id, embedding in zip(token_ids, embeddings):
+        for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
-
-        logger.info(f"Loaded textual inversion embedding for {token}.")


 class LoraLoaderMixin:
     r"""
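Before the accompanying test changes below, the two failure modes introduced by the validation above are worth spelling out. A hedged sketch (checkpoint and tokens are illustrative; the paths reuse the docstring examples and never need to exist, because both checks run before any file is downloaded or opened):

```python
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any pipeline using TextualInversionLoaderMixin behaves the same.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Mismatched lengths: two sources but only one token -> ValueError, raised up front.
try:
    pipe.load_textual_inversion(
        ["./my_text_inversion_directory/", "./my_text_inversions.pt"], token=["<only-one>"]
    )
except ValueError as e:
    print(e)

# Duplicate entries in `token` trip the new duplicate check, also a ValueError.
try:
    pipe.load_textual_inversion(
        ["./my_text_inversion_directory/", "./my_text_inversions.pt"], token=["<dup>", "<dup>"]
    )
except ValueError as e:
    print(e)
```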
@@ -575,6 +575,31 @@ class DownloadTests(unittest.TestCase):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)

+        # multi embedding load
+        with tempfile.TemporaryDirectory() as tmpdirname1:
+            with tempfile.TemporaryDirectory() as tmpdirname2:
+                ten = {"<*****>": torch.ones((32,))}
+                torch.save(ten, os.path.join(tmpdirname1, "learned_embeds.bin"))
+
+                ten = {"<******>": 2 * torch.ones((1, 32))}
+                torch.save(ten, os.path.join(tmpdirname2, "learned_embeds.bin"))
+
+                pipe.load_textual_inversion([tmpdirname1, tmpdirname2])
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<*****>")
+                assert token == num_tokens + 8, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
+                assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>"
+
+                token = pipe.tokenizer.convert_tokens_to_ids("<******>")
+                assert token == num_tokens + 9, "Added token must be at spot `num_tokens`"
+                assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+                assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>"
+
+                prompt = "hey <*****> <******>"
+                out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+                assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
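Outside the test suite, the same on-disk layout the test relies on can be reproduced by hand: a directory containing a learned_embeds.bin file that maps the placeholder token to its embedding tensor. A hedged sketch (checkpoint and token are illustrative, not from this PR):

```python
import os
import tempfile

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

with tempfile.TemporaryDirectory() as tmpdir:
    # The embedding row must match the text encoder's hidden size.
    hidden_size = pipe.text_encoder.get_input_embeddings().weight.shape[1]
    torch.save({"<my-concept>": torch.randn(hidden_size)}, os.path.join(tmpdir, "learned_embeds.bin"))

    # Directories, hub repo ids, and single .pt/.safetensors files can all appear in one batched call.
    pipe.load_textual_inversion([tmpdir], token=["<my-concept>"])
    print(pipe.tokenizer.convert_tokens_to_ids("<my-concept>"))
```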