Unverified commit 0d196f9f, authored by pdoane, committed by GitHub
Browse files

Fix issue in maybe_convert_prompt (#3188)

When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens.  Adding a space for the padding tokens fixes this.
parent 131312ca
@@ -410,7 +410,7 @@ class TextualInversionLoaderMixin:
             replacement = token
             i = 1
             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                replacement += f"{token}_{i}"
+                replacement += f" {token}_{i}"
                 i += 1
             prompt = prompt.replace(token, replacement)
@@ -541,7 +541,7 @@ class DownloadTests(unittest.TestCase):
         assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
         assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
         assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
-        assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"
+        assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***> <***>_1 <***>_2"
         prompt = "hey <***>"
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
@@ -569,7 +569,7 @@ class DownloadTests(unittest.TestCase):
         assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
         assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
         assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
-        assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"
+        assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****> <****>_1 <****>_2"
         prompt = "hey <****>"
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment