Unverified Commit 3132aac0 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix Idefics3 bug (#10778)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent c82b432d
...@@ -267,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext, ...@@ -267,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext,
n_images_in_text = [] n_images_in_text = []
text = inputs.get("prompt") text = inputs.get("prompt")
if text is not None: if text is None:
if isinstance(text, str): prompt_token_ids = inputs.get("prompt_token_ids", [])
text = [text] assert prompt_token_ids
elif not isinstance(text, list) and not isinstance(text[0], str): text = tokenizer.decode(prompt_token_ids)
raise ValueError("Invalid input text. Please provide a string, "
"or a list of strings") if isinstance(text, str):
text = [text]
fake_image_token = processor.fake_image_token.content elif not isinstance(text, list) and not isinstance(text[0], str):
image_token = processor.image_token.content raise ValueError("Invalid input text. Please provide a string, "
global_img_token = processor.global_image_tag "or a list of strings")
prompt_strings = [] fake_image_token = processor.fake_image_token.content
for sample, sample_rows, sample_cols in zip(text, image_rows, image_token = processor.image_token.content
image_cols): global_img_token = processor.global_image_tag
n_images_in_text.append(sample.count(image_token))
prompt_strings = []
# Replace the image token with fake tokens around the expanded for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# image token sequence of length `image_seq_len` n_images_in_text.append(sample.count(image_token))
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols): # Replace the image token with fake tokens around the expanded
image_prompt_string = _get_image_prompt_string( # image token sequence of length `image_seq_len`
n_rows, image_prompt_strings = []
n_cols, for n_rows, n_cols in zip(sample_rows, sample_cols):
processor.image_seq_len, image_prompt_string = _get_image_prompt_string(
image_token=image_token, n_rows,
fake_token_around_image=fake_image_token, n_cols,
global_img_token=global_img_token, processor.image_seq_len,
) image_token=image_token,
image_prompt_strings.append(image_prompt_string) fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
split_sample = sample.split(image_token) )
if len(split_sample) == 0: image_prompt_strings.append(image_prompt_string)
raise ValueError(
"The image token should be present in the text.")
# Place in the image prompt strings where the image tokens are split_sample = sample.split(image_token)
sample = split_sample[0] if len(split_sample) == 0:
for i, image_prompt_string in enumerate(image_prompt_strings): raise ValueError("The image token should be present in the text.")
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids # Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
return token_inputs( prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
prompt_token_ids=prompt_token_ids,
prompt=prompt_strings[0], return token_inputs(
multi_modal_data=multi_modal_data, prompt_token_ids=prompt_token_ids,
) prompt=prompt_strings[0],
multi_modal_data=multi_modal_data,
)
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment