Unverified Commit 712d0f88 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Dynamic `target` and `content` for prompt updates (#23411)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 49ab23b3
...@@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo, ...@@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement, apply_text_matches, PromptReplacement, apply_text_matches,
apply_token_matches, apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches, iter_token_matches,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
from .utils import random_image from .utils import random_image
...@@ -75,12 +73,15 @@ from .utils import random_image ...@@ -75,12 +73,15 @@ from .utils import random_image
), ),
], ],
) )
@pytest.mark.parametrize("start_idx", [0, 4, 8])
# yapf: enable # yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected): def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
result = list(iter_token_matches(token_ids, match_ids)) result = list(iter_token_matches(token_ids, match_ids,
start_idx=start_idx))
# Manually constructed results # Manually constructed results
assert [item._asdict() for item in result] == expected assert [item._asdict() for item in result
] == [item for item in expected if item["start_idx"] >= start_idx]
# Invariants # Invariants
match_lens = [end - start for start, end in result] match_lens = [end - start for start, end in result]
...@@ -241,21 +242,23 @@ def test_find_token_matches( ...@@ -241,21 +242,23 @@ def test_find_token_matches(
# Should not be used since there is nothing to convert to token IDs # Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [ prompt_updates = {
update_type(key, target, []).bind(mock_tokenizer) key: update_type(key, target, []).resolve(mock_tokenizer, 0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
result = find_token_matches(prompt, prompt_updates) result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert { assert {
key: [ key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx) dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, []) for item in result.get(key, [])
] ]
for key in expected_by_key for key in expected_by_key
} == expected_by_key } == expected_by_key
...@@ -388,21 +391,23 @@ def test_find_text_matches( ...@@ -388,21 +391,23 @@ def test_find_text_matches(
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [ prompt_updates = {
update_type(key, target, []).bind(mock_tokenizer) key: update_type(key, target, []).resolve(mock_tokenizer, 0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
result = find_text_matches(prompt, prompt_updates) result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert { assert {
key: [ key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx) dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, []) for item in result.get(key, [])
] ]
for key in expected_by_key for key in expected_by_key
} == expected_by_key } == expected_by_key
...@@ -552,39 +557,37 @@ def test_find_update_text( ...@@ -552,39 +557,37 @@ def test_find_update_text(
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
) in expected_by_update_type_mm_count.items(): ) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items(): for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches( mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_text_matches(
prompt, prompt,
mm_matches, mm_prompt_updates,
{key: mm_count mock_tokenizer,
for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("update_type:", update_type) print("update_type:", update_type)
print("mm_count:", mm_count) print("mm_count:", mm_count)
print("mm_matches:", mm_matches) print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
assert result == expected assert new_prompt == expected
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[ [
# Tokenized test cases of `test_find_replace_text` # Tokenized test cases of `test_find_update_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
( (
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
...@@ -726,32 +729,30 @@ def test_find_update_tokens( ...@@ -726,32 +729,30 @@ def test_find_update_tokens(
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
) in expected_by_update_type_mm_count.items(): ) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items(): for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches( mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_token_matches(
prompt, prompt,
mm_matches, mm_prompt_updates,
{key: mm_count mock_tokenizer,
for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("update_type:", update_type) print("update_type:", update_type)
print("mm_count:", mm_count) print("mm_count:", mm_count)
print("mm_matches:", mm_matches) print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
assert result == expected assert new_prompt == expected
# yapf: disable # yapf: disable
...@@ -878,17 +879,12 @@ def test_find_mm_placeholders( ...@@ -878,17 +879,12 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = { mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)] key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
for i in range(3)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
} }
result = find_mm_placeholders( result = find_mm_placeholders(prompt, mm_prompt_updates)
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
)
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
......
...@@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptUpdate,
PromptUpdate, PromptUpdateDetails, PromptUpdateDetails,
find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
...@@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _apply_token_matches( def _apply_token_matches(
self, self,
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
) -> list[int]: token_ids, res = super()._apply_token_matches(prompt,
token_ids = super()._apply_token_matches( mm_prompt_updates)
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens # "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n" # Since our replacement can insert "\n\n" next to "\n"
...@@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
[newline_4], [newline_4],
) )
return token_ids return token_ids, res
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
...@@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
repl_token_ids.extend(repl_toks) repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
mm_item_counts)
return { return {
modality: [ modality: [
......
...@@ -29,10 +29,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, ...@@ -29,10 +29,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptUpdate,
PromptUpdate, PromptUpdateDetails, PromptUpdateDetails,
find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
...@@ -254,14 +256,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] ...@@ -254,14 +256,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
def _apply_token_matches( def _apply_token_matches(
self, self,
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int], ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
) -> list[int]: token_ids, res = super()._apply_token_matches(prompt,
token_ids = super()._apply_token_matches( mm_prompt_updates)
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens # "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n" # Since our replacement can insert "\n\n" next to "\n"
...@@ -290,13 +288,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] ...@@ -290,13 +288,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
[newline_4], [newline_4],
) )
return token_ids return token_ids, res
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
...@@ -321,8 +318,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] ...@@ -321,8 +318,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
repl_token_ids.extend(repl_toks) repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
mm_item_counts)
return { return {
modality: [ modality: [
......
...@@ -828,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -828,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
target=[image_token_id] * num_image_tokens, target=[image_token_id] * num_image_tokens,
replacement=get_replacement_mantis, replacement=get_replacement_mantis,
) )
]) ], mm_item_counts)
prompt_ids, prompt, _ = self._apply_prompt_updates( prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_mm_repls, mantis_mm_repls,
mm_item_counts,
) )
unbound_orig_repls = self._get_prompt_updates( orig_repls = self._get_mm_prompt_updates(
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
orig_repls = self._bind_and_group_updates(unbound_orig_repls) mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,
prompt_ids,
mm_item_counts,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = { mm_placeholder_ranges = {
......
...@@ -38,7 +38,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -38,7 +38,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
# yapf: enable # yapf: enable
...@@ -431,24 +432,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -431,24 +432,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
return [_IMAGE_TOKEN_ID] * num_image_tokens return [_IMAGE_TOKEN_ID] * num_image_tokens
num_images = mm_items.get_count("image", strict=False)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=image_token, target=image_tokens.__getitem__,
replacement=get_replacement_phi3v, replacement=get_replacement_phi3v,
) for image_token in image_tokens[:num_images] )
] ]
def _apply_prompt_updates( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], mm_prompt_updates: MultiModalPromptUpdates,
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images # align to hf behavior when there are images
if len(mm_item_counts): if len(mm_prompt_updates):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
# to decode token_ids to the original text, we need to # to decode token_ids to the original text, we need to
# 1. remove the first bos token # 1. remove the first bos token
...@@ -484,7 +482,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -484,7 +482,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
token_ids, text, placeholders = super()._apply_prompt_updates( token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids, token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates, mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
) )
# Keep the behavior in line with HF processor # Keep the behavior in line with HF processor
......
...@@ -1032,8 +1032,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -1032,8 +1032,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.vocab[tokenizer.image_token] image_token_id: int = tokenizer.vocab[tokenizer.image_token]
audio_token_id = tokenizer.vocab[tokenizer.audio_token] audio_token_id: int = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor( audio_processor = self.info.get_feature_extractor(
...@@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor, processor=hf_processor,
) )
image_tokens = [image_token_id] * num_image_tokens return [image_token_id] * num_image_tokens
return image_tokens
def get_audio_replacement_phi4mm(item_idx: int): def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
...@@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size( audio_embed_size = self.info._compute_audio_embed_size(
audio_frames) audio_frames)
audio_tokens = [audio_token_id] * audio_embed_size return [audio_token_id] * audio_embed_size
return audio_tokens
return [ return [
PromptReplacement( PromptReplacement(
......
...@@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor, processor=hf_processor,
) )
image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
return image_tokens
def get_audio_replacement_phi4mm(item_idx: int): def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
...@@ -837,28 +835,20 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -837,28 +835,20 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size( audio_embed_size = self.info._compute_audio_embed_size(
audio_frames) audio_frames)
audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
return audio_tokens
num_images = mm_items.get_count("image", strict=False) return [
num_audios = mm_items.get_count("audio", strict=False)
image_repl = [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=image_token, target=image_tokens.__getitem__,
replacement=get_image_replacement_phi4mm, replacement=get_image_replacement_phi4mm,
) for image_token in image_tokens[:num_images] ),
]
audio_repl = [
PromptReplacement( PromptReplacement(
modality="audio", modality="audio",
target=audio_token, target=audio_tokens.__getitem__,
replacement=get_audio_replacement_phi4mm, replacement=get_audio_replacement_phi4mm,
) for audio_token in audio_tokens[:num_audios] ),
] ]
return image_repl + audio_repl
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
......
...@@ -309,9 +309,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -309,9 +309,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if is_update_applied: if is_update_applied:
mm_placeholders = self._find_mm_placeholders( mm_placeholders = self._find_mm_placeholders(
mm_prompt_updates,
prompt_ids, prompt_ids,
mm_item_counts, mm_prompt_updates,
) )
self._validate_mm_placeholders( self._validate_mm_placeholders(
mm_placeholders, mm_placeholders,
...@@ -328,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -328,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
) = self._apply_prompt_updates( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_updates, mm_prompt_updates,
mm_item_counts,
) )
self._validate_mm_placeholders( self._validate_mm_placeholders(
mm_placeholders, mm_placeholders,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment