"vscode:/vscode.git/clone" did not exist on "b41ce1e0905e53f29d4f61d43782b5e1219b0ddf"
Unverified Commit a6c7b5b6 authored by Lachlan Nicholson's avatar Lachlan Nicholson Committed by GitHub
Browse files

Iterate over unique tokens to avoid duplicate replacements for multivector embeddings (#3588)

* iterate over unique tokens to avoid duplicate replacements

* added test for multiple references to multi embedding

* adhere to black formatting

* reorder test post-rebase
parent 8e552bb4
......@@ -462,7 +462,8 @@ class TextualInversionLoaderMixin:
`str`: The converted prompt
"""
tokens = tokenizer.tokenize(prompt)
for token in tokens:
unique_tokens = set(tokens)
for token in unique_tokens:
if token in tokenizer.added_tokens_encoder:
replacement = token
i = 1
......
......@@ -722,6 +722,18 @@ class DownloadTests(unittest.TestCase):
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# multiple references to multi embedding
ten = {"<cat>": torch.ones(3, 32)}
pipe.load_textual_inversion(ten)
assert (
pipe._maybe_convert_prompt("<cat> <cat>", pipe.tokenizer) == "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2"
)
prompt = "hey <cat> <cat>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment