Unverified Commit 45572c24 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix the get_indices function (#2418)


Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 5f65ef4d
...@@ -47,6 +47,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -47,6 +47,7 @@ EXAMPLE_DOC_STRING = """
>>> # use get_indices function to find out indices of the tokens you want to alter >>> # use get_indices function to find out indices of the tokens you want to alter
>>> pipe.get_indices(prompt) >>> pipe.get_indices(prompt)
{0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', 3: 'and</w>', 4: 'a</w>', 5: 'frog</w>', 6: '<|endoftext|>'}
>>> token_indices = [2, 5] >>> token_indices = [2, 5]
>>> seed = 6141 >>> seed = 6141
...@@ -662,7 +663,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): ...@@ -662,7 +663,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
def get_indices(self, prompt: str) -> Dict[str, int]: def get_indices(self, prompt: str) -> Dict[str, int]:
"""Utility function to list the indices of the tokens you wish to alte""" """Utility function to list the indices of the tokens you wish to alte"""
ids = self.tokenizer(prompt).input_ids ids = self.tokenizer(prompt).input_ids
indices = {tok: i for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))} indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))}
return indices return indices
@torch.no_grad() @torch.no_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment