"...text-generation-inference.git" did not exist on "444400b45760ab54faac0689f2f815c1fb425a8f"
Unverified commit a6c7b5b6 authored by Lachlan Nicholson, committed by GitHub

Iterate over unique tokens to avoid duplicate replacements for multivector embeddings (#3588)

* iterate over unique tokens to avoid duplicate replacements

* added test for multiple references to multi embedding

* adhere to black formatting

* reorder test post-rebase
parent 8e552bb4
@@ -462,7 +462,8 @@ class TextualInversionLoaderMixin:
             `str`: The converted prompt
         """
         tokens = tokenizer.tokenize(prompt)
-        for token in tokens:
+        unique_tokens = set(tokens)
+        for token in unique_tokens:
             if token in tokenizer.added_tokens_encoder:
                 replacement = token
                 i = 1
......
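For context, here is a minimal standalone sketch of the behavior the hunk above fixes. It is not the diffusers implementation: convert_prompt is a hypothetical helper, and added_tokens stands in for tokenizer.added_tokens_encoder. Because str.replace rewrites every occurrence of a token, visiting a duplicated token more than once re-expands text that was already expanded; iterating over set(tokens) applies each multi-vector expansion exactly once.

# Minimal sketch (hypothetical helper, not the diffusers code): a 3-vector
# embedding "<cat>" also registers the suffix tokens "<cat>_1" and "<cat>_2".
def convert_prompt(prompt, tokens, added_tokens, dedupe):
    iterable = set(tokens) if dedupe else tokens
    for token in iterable:
        if token in added_tokens:
            # Build "<cat> <cat>_1 <cat>_2" from the registered suffix tokens.
            replacement = token
            i = 1
            while f"{token}_{i}" in added_tokens:
                replacement += f" {token}_{i}"
                i += 1
            # str.replace rewrites every occurrence, so a second visit to the
            # same token re-expands text that was already expanded.
            prompt = prompt.replace(token, replacement)
    return prompt

added_tokens = {"<cat>", "<cat>_1", "<cat>_2"}
tokens = ["<cat>", "<cat>"]  # the prompt references the embedding twice

print(convert_prompt("<cat> <cat>", tokens, added_tokens, dedupe=True))
# -> "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2" (matches the new test below)
print(convert_prompt("<cat> <cat>", tokens, added_tokens, dedupe=False))
# -> the second pass expands the already-inserted "<cat>" substrings again,
#    leaving duplicated, mangled suffix tokens in the prompt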
@@ -722,6 +722,18 @@ class DownloadTests(unittest.TestCase):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)
 
+        # multiple references to multi embedding
+        ten = {"<cat>": torch.ones(3, 32)}
+        pipe.load_textual_inversion(ten)
+
+        assert (
+            pipe._maybe_convert_prompt("<cat> <cat>", pipe.tokenizer) == "<cat> <cat>_1 <cat>_2 <cat> <cat>_1 <cat>_2"
+        )
+
+        prompt = "hey <cat> <cat>"
+        out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
+        assert out.shape == (1, 128, 128, 3)
+
     def test_download_ignore_files(self):
         # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
         with tempfile.TemporaryDirectory() as tmpdirname:
......