Unverified Commit 1af5d25e authored by bytesurfer3's avatar bytesurfer3 Committed by GitHub
Browse files

Update pipeline.py - issue #63 (#65)

This change replaces the whitespace-sensitive string replacement with a token-based approach: the entire prompt is tokenized, the token corresponding to the trigger word is located and removed, and the remaining tokens are decoded back into text.
parent 8d37e305
......@@ -344,7 +344,15 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
)
# 4. Encode input prompt without the trigger word for delayed conditioning
prompt_text_only = prompt.replace(" "+self.trigger_word, "") # sensitive to white space
tokens = self.tokenizer.encode(prompt, return_tensors="pt")
trigger_word_token = self.tokenizer.encode(self.trigger_word, return_tensors="pt")[0, 0]
trigger_word_index = (tokens[0] == trigger_word_token).nonzero()
if trigger_word_index.numel() > 0:
tokens = torch.cat((tokens[:, :trigger_word_index], tokens[:, trigger_word_index + 1:]), dim=1)
prompt_text_only = self.tokenizer.decode(tokens[0].tolist())
else:
prompt_text_only = prompt
(
prompt_embeds_text_only,
negative_prompt_embeds,
......@@ -490,4 +498,4 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
\ No newline at end of file
return StableDiffusionXLPipelineOutput(images=image)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment