Unverified commit d45cbe70, authored by Travis Johnson, committed by GitHub
Browse files

[Bugfix] Check that number of images matches number of <|image|> tokens with mllama (#11939)


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent 8a579408
...@@ -123,6 +123,13 @@ def input_processor_for_mllama( ...@@ -123,6 +123,13 @@ def input_processor_for_mllama(
assert is_list_of(image_data, Image.Image) assert is_list_of(image_data, Image.Image)
num_image_tokens = dec_inputs['prompt_token_ids'].count(
MLLAMA_IMAGE_TOKEN_ID)
if num_image_tokens != len(image_data):
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({len(image_data)})")
# Since only the last group of consecutive images # Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to # are attended by the decoded tokens, we only need to
# get the number of tiles for those images. # get the number of tiles for those images.
...@@ -1493,6 +1500,8 @@ def convert_sparse_cross_attention_mask_to_dense( ...@@ -1493,6 +1500,8 @@ def convert_sparse_cross_attention_mask_to_dense(
dense_mask[seq_start + start:seq_start + end, dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1 tile_start:tile_start + tile] = 1
tile_start += tile tile_start += tile
assert ts != -1
assert td != 0
tile_range_for_decode.append((ts, ts + td)) tile_range_for_decode.append((ts, ts + td))
seq_start += length seq_start += length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment