Unverified Commit 5d3e7bda authored by bonlime, committed by GitHub
Browse files

Fix bug in Textual Inversion Unloading (#9304)



* Update textual_inversion.py

* add unload test

* add comment

* fix style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 2541d141
...@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin: ...@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1 key_id += 1
tokenizer._update_trie() tokenizer._update_trie()
# set correct total vocab size after removing tokens
tokenizer._update_total_vocab_size()
# Delete from text encoder # Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
......
...@@ -947,6 +947,27 @@ class DownloadTests(unittest.TestCase): ...@@ -947,6 +947,27 @@ class DownloadTests(unittest.TestCase):
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item() emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
) )
def test_textual_inversion_unload(self):
    """Check that unloading textual inversions restores the original sizes.

    A textual inversion is loaded and unloaded twice in a row; afterwards
    the tokenizer length and the number of rows in the text encoder's
    input-embedding weight must both equal their values before any load.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
    ).to(torch_device)

    def snapshot_sizes():
        # (tokenizer length, number of rows in the input-embedding weight)
        tokenizer_len = len(pipeline.tokenizer)
        embedding_rows = len(pipeline.text_encoder.get_input_embeddings().weight)
        return tokenizer_len, embedding_rows

    tokenizer_len_before, embedding_rows_before = snapshot_sizes()

    placeholder = "<*>"
    # 1-D embedding of ones; presumably 32 matches the tiny test model's
    # text-embedding width -- confirm against the checkpoint config.
    embedding = torch.ones((32,))

    # Run the load/unload cycle twice so the second load operates on the
    # state left behind by the first unload.
    for _ in range(2):
        pipeline.load_textual_inversion(embedding, token=placeholder)
        pipeline.unload_textual_inversion()

    tokenizer_len_after, embedding_rows_after = snapshot_sizes()

    # both should be restored to original size
    assert tokenizer_len_after == tokenizer_len_before
    assert embedding_rows_after == embedding_rows_before
def test_download_ignore_files(self): def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment