Unverified Commit 799f5b4e authored by Greg Hunkins, committed by GitHub
Browse files

[Feat] Enable State Dict For Textual Inversion Loader (#3439)



* enable state dict for textual inversion loader

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* add tests

* fix tests

* fix tests

* fix tests

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 07ef4855
...@@ -470,7 +470,7 @@ class TextualInversionLoaderMixin: ...@@ -470,7 +470,7 @@ class TextualInversionLoaderMixin:
def load_textual_inversion( def load_textual_inversion(
self, self,
pretrained_model_name_or_path: Union[str, List[str]], pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None, token: Optional[Union[str, List[str]]] = None,
**kwargs, **kwargs,
): ):
...@@ -485,7 +485,7 @@ class TextualInversionLoaderMixin: ...@@ -485,7 +485,7 @@ class TextualInversionLoaderMixin:
</Tip> </Tip>
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`): pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either: Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
...@@ -494,6 +494,8 @@ class TextualInversionLoaderMixin: ...@@ -494,6 +494,8 @@ class TextualInversionLoaderMixin:
- A path to a *directory* containing textual inversion weights, e.g. - A path to a *directory* containing textual inversion weights, e.g.
`./my_text_inversion_directory/`. `./my_text_inversion_directory/`.
- A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`. - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
Or a list of those elements. Or a list of those elements.
token (`str` or `List[str]`, *optional*): token (`str` or `List[str]`, *optional*):
...@@ -618,7 +620,7 @@ class TextualInversionLoaderMixin: ...@@ -618,7 +620,7 @@ class TextualInversionLoaderMixin:
"framework": "pytorch", "framework": "pytorch",
} }
if isinstance(pretrained_model_name_or_path, str): if not isinstance(pretrained_model_name_or_path, list):
pretrained_model_name_or_paths = [pretrained_model_name_or_path] pretrained_model_name_or_paths = [pretrained_model_name_or_path]
else: else:
pretrained_model_name_or_paths = pretrained_model_name_or_path pretrained_model_name_or_paths = pretrained_model_name_or_path
...@@ -643,6 +645,7 @@ class TextualInversionLoaderMixin: ...@@ -643,6 +645,7 @@ class TextualInversionLoaderMixin:
token_ids_and_embeddings = [] token_ids_and_embeddings = []
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
if not isinstance(pretrained_model_name_or_path, dict):
# 1. Load textual inversion file # 1. Load textual inversion file
model_file = None model_file = None
# Let's first try to load .safetensors weights # Let's first try to load .safetensors weights
...@@ -685,6 +688,8 @@ class TextualInversionLoaderMixin: ...@@ -685,6 +688,8 @@ class TextualInversionLoaderMixin:
user_agent=user_agent, user_agent=user_agent,
) )
state_dict = torch.load(model_file, map_location="cpu") state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path
# 2. Load token and embedding correctly from file # 2. Load token and embedding correctly from file
loaded_token = None loaded_token = None
......
...@@ -663,6 +663,65 @@ class DownloadTests(unittest.TestCase): ...@@ -663,6 +663,65 @@ class DownloadTests(unittest.TestCase):
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3) assert out.shape == (1, 128, 128, 3)
# single token state dict load
ten = {"<x>": torch.ones((32,))}
pipe.load_textual_inversion(ten)
token = pipe.tokenizer.convert_tokens_to_ids("<x>")
assert token == num_tokens + 10, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
assert pipe._maybe_convert_prompt("<x>", pipe.tokenizer) == "<x>"
prompt = "hey <x>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# multi embedding state dict load
ten1 = {"<xxxxx>": torch.ones((32,))}
ten2 = {"<xxxxxx>": 2 * torch.ones((1, 32))}
pipe.load_textual_inversion([ten1, ten2])
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxx>")
assert token == num_tokens + 11, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
assert pipe._maybe_convert_prompt("<xxxxx>", pipe.tokenizer) == "<xxxxx>"
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxxx>")
assert token == num_tokens + 12, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
assert pipe._maybe_convert_prompt("<xxxxxx>", pipe.tokenizer) == "<xxxxxx>"
prompt = "hey <xxxxx> <xxxxxx>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# auto1111 multi-token state dict load
ten = {
"string_to_param": {
"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
},
"name": "<xxxx>",
}
pipe.load_textual_inversion(ten)
token = pipe.tokenizer.convert_tokens_to_ids("<xxxx>")
token_1 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_1")
token_2 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_2")
assert token == num_tokens + 13, "Added token must be at spot `num_tokens`"
assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`"
assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
assert pipe._maybe_convert_prompt("<xxxx>", pipe.tokenizer) == "<xxxx> <xxxx>_1 <xxxx>_2"
prompt = "hey <xxxx>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
def test_download_ignore_files(self): def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment