Unverified Commit 3d8b3d7c authored by pdoane, committed by GitHub

Batched load of textual inversions (#3277)



* Batched load of textual inversions

- Only call resize_token_embeddings once per batch, as it is the most expensive operation
- Allow pretrained_model_name_or_path and token to each be a list (token remains optional; see the usage sketch below)
- Remove Dict from the type annotation of pretrained_model_name_or_path, as it was not supported by this function
- Add a comment that single files (e.g. .pt/.safetensors) are supported
- Add a comment for the token parameter
- Convert the token-override log message from warning to info

* Update src/diffusers/loaders.py

Check for duplicate tokens
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update condition for None tokens

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
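
For illustration, a minimal usage sketch of the batched call introduced by this commit, assuming a StableDiffusionPipeline has already been loaded as pipe; the model path and token strings below are hypothetical placeholders, not part of the commit:

# Load two embeddings in one call; the token embeddings are resized once for the whole batch,
# not once per embedding.
pipe.load_textual_inversion(
    ["sd-concepts-library/low-poly-hd-logos-icons", "./my_text_inversions.pt"],
    token=["<logo-style>", "<my-style>"],  # must be the same length as the list of paths
)
image = pipe("a castle drawn in <logo-style>, <my-style>").images[0]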
parent 0ffac979
src/diffusers/loaders.py

@@ -436,7 +436,10 @@ class TextualInversionLoaderMixin:
        return prompt

    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str]],
        token: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ):
        r"""
        Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
@@ -449,7 +452,7 @@ class TextualInversionLoaderMixin:
        </Tip>

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -457,6 +460,12 @@ class TextualInversionLoaderMixin:
                      `"sd-concepts-library/low-poly-hd-logos-icons"`.
                    - A path to a *directory* containing textual inversion weights, e.g.
                      `./my_text_inversion_directory/`.
                    - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.

                Or a list of those elements.
            token (`str` or `List[str]`, *optional*):
                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                list, then `token` must also be a list of equal length.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used in two cases:
@@ -576,16 +585,62 @@ class TextualInversionLoaderMixin:
            "framework": "pytorch",
        }

        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
        else:
            pretrained_model_name_or_paths = pretrained_model_name_or_path

        if isinstance(token, str):
            tokens = [token]
        elif token is None:
            tokens = [None] * len(pretrained_model_name_or_paths)
        else:
            tokens = token

        if len(pretrained_model_name_or_paths) != len(tokens):
            raise ValueError(
                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}. "
                f"Make sure both lists have the same length."
            )

        valid_tokens = [t for t in tokens if t is not None]
        if len(set(valid_tokens)) < len(valid_tokens):
            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

        token_ids_and_embeddings = []

        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
            # 1. Load textual inversion file
            model_file = None

            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except Exception as e:
                    if not allow_pickle:
                        raise e

                    model_file = None

            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
@@ -596,88 +651,68 @@ class TextualInversionLoaderMixin:
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = torch.load(model_file, map_location="cpu")

            # 2. Load token and embedding correctly from file
            if isinstance(state_dict, torch.Tensor):
                if token is None:
                    raise ValueError(
                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                    )
                embedding = state_dict
            elif len(state_dict) == 1:
                # diffusers
                loaded_token, embedding = next(iter(state_dict.items()))
            elif "string_to_param" in state_dict:
                # A1111
                loaded_token = state_dict["name"]
                embedding = state_dict["string_to_param"]["*"]

            if token is not None and loaded_token != token:
                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
            else:
                token = loaded_token

            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

            # 3. Make sure we don't mess up the tokenizer or text encoder
            vocab = self.tokenizer.get_vocab()
            if token in vocab:
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                )
            elif f"{token}_1" in vocab:
                multi_vector_tokens = [token]
                i = 1
                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
                    multi_vector_tokens.append(f"{token}_{i}")
                    i += 1

                raise ValueError(
                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )

            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

            if is_multi_vector:
                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
                embeddings = [e for e in embedding]  # noqa: C416
            else:
                tokens = [token]
                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

            # add tokens and get ids
            self.tokenizer.add_tokens(tokens)
            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            token_ids_and_embeddings += zip(token_ids, embeddings)

            logger.info(f"Loaded textual inversion embedding for {token}.")

        # resize token embeddings and set all new embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        for token_id, embedding in token_ids_and_embeddings:
            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

class LoraLoaderMixin:
    r"""
...
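
For reference, a minimal sketch of the three on-disk layouts accepted by the "# 2. Load token and embedding" branch above; the file names, token strings, and tensor shapes are dummy placeholders:

import torch

# diffusers format: a dict with a single {token: embedding} entry (the token is read from the key)
torch.save({"<my-token>": torch.randn(768)}, "learned_embeds.bin")

# A1111 format: the token is stored under "name", the embedding(s) under string_to_param["*"]
torch.save({"name": "<my-token>", "string_to_param": {"*": torch.randn(2, 768)}}, "a1111_embed.pt")

# raw tensor: the file carries no token, so token=... must be passed to load_textual_inversion
torch.save(torch.randn(768), "raw_embed.pt")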
@@ -575,6 +575,31 @@ class DownloadTests(unittest.TestCase):
        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
        assert out.shape == (1, 128, 128, 3)

        # multi embedding load
        with tempfile.TemporaryDirectory() as tmpdirname1:
            with tempfile.TemporaryDirectory() as tmpdirname2:
                ten = {"<*****>": torch.ones((32,))}
                torch.save(ten, os.path.join(tmpdirname1, "learned_embeds.bin"))

                ten = {"<******>": 2 * torch.ones((1, 32))}
                torch.save(ten, os.path.join(tmpdirname2, "learned_embeds.bin"))

                pipe.load_textual_inversion([tmpdirname1, tmpdirname2])

                token = pipe.tokenizer.convert_tokens_to_ids("<*****>")
                assert token == num_tokens + 8, "Added token must be at spot `num_tokens`"
                assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
                assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>"

                token = pipe.tokenizer.convert_tokens_to_ids("<******>")
                assert token == num_tokens + 9, "Added token must be at spot `num_tokens`"
                assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
                assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>"

                prompt = "hey <*****> <******>"
                out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
                assert out.shape == (1, 128, 128, 3)

    def test_download_ignore_files(self):
        # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
        with tempfile.TemporaryDirectory() as tmpdirname:
...