Unverified commit 799f5b4e authored by Greg Hunkins, committed by GitHub

[Feat] Enable State Dict For Textual Inversion Loader (#3439)



* enable state dict for textual inversion loader

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* add tests

* fix tests

* fix tests

* fix tests

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 07ef4855
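
The net effect of this change: `load_textual_inversion` now also accepts an in-memory torch state dict, or a list of them, in place of a model id or file path. A minimal usage sketch, assuming a Stable Diffusion 1.x pipeline; the checkpoint, placeholder token, and embedding values here are illustrative, not from this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A textual inversion state dict maps a placeholder token to its learned
# embedding. The width must match the text encoder hidden size (768 for SD 1.x).
embedding = {"<my-token>": torch.randn(768)}

# New in this commit: pass the dict (or a list of dicts) directly,
# instead of a hub model id or a local weights file.
pipe.load_textual_inversion(embedding)

image = pipe("a photo of <my-token>", num_inference_steps=25).images[0]
```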
Changes to `TextualInversionLoaderMixin.load_textual_inversion`:

```diff
@@ -470,7 +470,7 @@ class TextualInversionLoaderMixin:
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ class TextualInversionLoaderMixin:
         </Tip>

         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                 Can be either:

                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ class TextualInversionLoaderMixin:
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
                     - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

                 Or a list of those elements.

             token (`str` or `List[str]`, *optional*):
@@ -618,7 +620,7 @@ class TextualInversionLoaderMixin:
             "framework": "pytorch",
         }

-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
             pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         else:
             pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ class TextualInversionLoaderMixin:
         token_ids_and_embeddings = []

         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ class TextualInversionLoaderMixin:
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except Exception as e:
-                    if not allow_pickle:
-                        raise e
-
-                    model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path

             # 2. Load token and embedding correcly from file
             loaded_token = None
```
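The heart of the change is the dispatch at the top of the loading loop: non-dict inputs still go through `_get_model_file` (safetensors first, with a pickled `torch.load` fallback when `allow_pickle` is set), while a dict skips file resolution entirely and is used as the state dict. A condensed sketch of that control flow, not the verbatim library code:

```python
import torch
import safetensors.torch

def resolve_state_dict(source, allow_pickle=True):
    # Condensed illustration of the branch this commit introduces.
    if isinstance(source, dict):
        # In-memory state dict: use it as-is, no download or file I/O.
        return source

    # Path-like input: prefer safetensors, fall back to pickle if allowed.
    try:
        return safetensors.torch.load_file(source, device="cpu")
    except Exception:
        if not allow_pickle:
            raise
        return torch.load(source, map_location="cpu")
```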
Changes to `DownloadTests`:

```diff
@@ -663,6 +663,65 @@ class DownloadTests(unittest.TestCase):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)

+        # single token state dict load
+        ten = {"<x>": torch.ones((32,))}
+        pipe.load_textual_inversion(ten)
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<x>")
+        assert token == num_tokens + 10, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
+        assert pipe._maybe_convert_prompt("<x>", pipe.tokenizer) == "<x>"
+
+        prompt = "hey <x>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
+        # multi embedding state dict load
+        ten1 = {"<xxxxx>": torch.ones((32,))}
+        ten2 = {"<xxxxxx>": 2 * torch.ones((1, 32))}
+
+        pipe.load_textual_inversion([ten1, ten2])
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxxx>")
+        assert token == num_tokens + 11, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
+        assert pipe._maybe_convert_prompt("<xxxxx>", pipe.tokenizer) == "<xxxxx>"
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxxxx>")
+        assert token == num_tokens + 12, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
+        assert pipe._maybe_convert_prompt("<xxxxxx>", pipe.tokenizer) == "<xxxxxx>"
+
+        prompt = "hey <xxxxx> <xxxxxx>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
+        # auto1111 multi-token state dict load
+        ten = {
+            "string_to_param": {
+                "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
+            },
+            "name": "<xxxx>",
+        }
+
+        pipe.load_textual_inversion(ten)
+
+        token = pipe.tokenizer.convert_tokens_to_ids("<xxxx>")
+        token_1 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_1")
+        token_2 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_2")
+        assert token == num_tokens + 13, "Added token must be at spot `num_tokens`"
+        assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`"
+        assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`"
+        assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
+        assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
+        assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
+        assert pipe._maybe_convert_prompt("<xxxx>", pipe.tokenizer) == "<xxxx> <xxxx>_1 <xxxx>_2"
+
+        prompt = "hey <xxxx>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
```
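The last test exercises the Automatic1111 embedding format, where `name` holds the placeholder and `string_to_param["*"]` holds a multi-vector embedding; as the assertions show, the loader registers one token per vector (`<xxxx>`, `<xxxx>_1`, `<xxxx>_2`) and `_maybe_convert_prompt` expands the base token in prompts. A sketch of loading such a payload against a real pipeline (checkpoint and token names are illustrative; 768 matches the SD 1.x text encoder rather than the 32-dim toy model used in the test):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Automatic1111-style payload: one concept name, three embedding vectors.
a1111_embedding = {
    "string_to_param": {"*": torch.randn(3, 768)},
    "name": "<concept>",
}

pipe.load_textual_inversion(a1111_embedding)

# <concept>, <concept>_1 and <concept>_2 are now in the tokenizer; prompts
# containing <concept> are expanded to all three tokens internally.
image = pipe("a painting of <concept>", num_inference_steps=25).images[0]
```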