Commit 5c4839a7 authored by Zhen Li's avatar Zhen Li
Browse files

Fix crash on replacing the trigger word to prompt_text_only

parent 1af5d25e
......@@ -344,15 +344,11 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
)
# 4. Encode input prompt without the trigger word for delayed conditioning
tokens = self.tokenizer.encode(prompt, return_tensors="pt")
trigger_word_token = self.tokenizer.encode(self.trigger_word, return_tensors="pt")[0, 0]
trigger_word_index = (tokens[0] == trigger_word_token).nonzero()
if trigger_word_index.numel() > 0:
tokens = torch.cat((tokens[:, :trigger_word_index], tokens[:, trigger_word_index + 1:]), dim=1)
prompt_text_only = self.tokenizer.decode(tokens[0].tolist())
else:
prompt_text_only = prompt
# encode, remove trigger word token, then decode
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
tokens_text_only.remove(trigger_word_token)
prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
(
prompt_embeds_text_only,
negative_prompt_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment