Unverified Commit fe3398fa authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Enable passing `tokenizer=None` into MM processor (#29724)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent ad7f714d
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from typing import cast
import numpy as np import numpy as np
import pytest import pytest
...@@ -24,7 +23,6 @@ from vllm.multimodal.processing import ( ...@@ -24,7 +23,6 @@ from vllm.multimodal.processing import (
replace_token_matches, replace_token_matches,
) )
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.tokenizers import TokenizerLike
from .utils import random_image from .utils import random_image
...@@ -238,15 +236,12 @@ def test_find_token_matches( ...@@ -238,15 +236,12 @@ def test_find_token_matches(
expected_by_key, expected_by_key,
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
} }
result = { result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer)) key: list(update.iter_token_matches(prompt, tokenizer=None))
for key, update in prompt_updates.items() for key, update in prompt_updates.items()
} }
...@@ -385,15 +380,12 @@ def test_find_text_matches( ...@@ -385,15 +380,12 @@ def test_find_text_matches(
expected_by_key, expected_by_key,
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items() for key, target in target_by_key.items()
} }
result = { result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer)) key: list(update.iter_text_matches(prompt, tokenizer=None))
for key, update in prompt_updates.items() for key, update in prompt_updates.items()
} }
...@@ -545,9 +537,6 @@ def test_find_update_text( ...@@ -545,9 +537,6 @@ def test_find_update_text(
repl_by_key, repl_by_key,
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
...@@ -564,7 +553,7 @@ def test_find_update_text( ...@@ -564,7 +553,7 @@ def test_find_update_text(
new_prompt, result = apply_text_matches( new_prompt, result = apply_text_matches(
prompt, prompt,
mm_prompt_updates, mm_prompt_updates,
mock_tokenizer, tokenizer=None,
) )
# Only displayed on error # Only displayed on error
...@@ -750,9 +739,6 @@ def test_find_update_tokens( ...@@ -750,9 +739,6 @@ def test_find_update_tokens(
repl_by_key, repl_by_key,
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
expected_by_mm_count, expected_by_mm_count,
...@@ -769,7 +755,7 @@ def test_find_update_tokens( ...@@ -769,7 +755,7 @@ def test_find_update_tokens(
new_prompt, result = apply_token_matches( new_prompt, result = apply_token_matches(
prompt, prompt,
mm_prompt_updates, mm_prompt_updates,
mock_tokenizer, tokenizer=None,
) )
# Only displayed on error # Only displayed on error
...@@ -900,15 +886,12 @@ def test_find_mm_placeholders( ...@@ -900,15 +886,12 @@ def test_find_mm_placeholders(
expected, expected,
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = { mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
} }
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
...@@ -1029,12 +1012,9 @@ def test_hf_processor_init_kwargs( ...@@ -1029,12 +1012,9 @@ def test_hf_processor_init_kwargs(
inference_kwargs, inference_kwargs,
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer, tokenizer=None,
) )
processor = ctx.get_hf_processor( processor = ctx.get_hf_processor(
...@@ -1065,12 +1045,9 @@ def test_hf_processor_call_kwargs( ...@@ -1065,12 +1045,9 @@ def test_hf_processor_call_kwargs(
inference_kwargs, inference_kwargs,
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
tokenizer=mock_tokenizer, tokenizer=None,
) )
processor = ctx.get_hf_processor(DummyProcessor) # type: ignore[arg-type] processor = ctx.get_hf_processor(DummyProcessor) # type: ignore[arg-type]
...@@ -1089,8 +1066,6 @@ def test_apply_matches_no_match_exits_quickly(): ...@@ -1089,8 +1066,6 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found. With the fix, it should exit immediately when no match is found.
""" """
mock_tokenizer = cast(TokenizerLike, object())
# Create a long prompt with no placeholder # Create a long prompt with no placeholder
long_prompt = "x" * 10000 long_prompt = "x" * 10000
...@@ -1103,7 +1078,7 @@ def test_apply_matches_no_match_exits_quickly(): ...@@ -1103,7 +1078,7 @@ def test_apply_matches_no_match_exits_quickly():
result, _ = _apply_matches( result, _ = _apply_matches(
long_prompt, long_prompt,
mm_prompt_updates, mm_prompt_updates,
mock_tokenizer, tokenizer=None,
) )
elapsed = time.perf_counter() - start elapsed = time.perf_counter() - start
......
...@@ -337,7 +337,7 @@ class OpenAIServing: ...@@ -337,7 +337,7 @@ class OpenAIServing:
tokenizer = input_processor.tokenizer tokenizer = input_processor.tokenizer
if tokenizer is None: if tokenizer is None:
raise ValueError( raise ValueError(
"You cannot use beam search when `skip_tokenizer_init` is True" "You cannot use beam search when `skip_tokenizer_init=True`"
) )
eos_token_id: int = tokenizer.eos_token_id # type: ignore eos_token_id: int = tokenizer.eos_token_id # type: ignore
......
...@@ -62,7 +62,7 @@ class InputPreprocessor: ...@@ -62,7 +62,7 @@ class InputPreprocessor:
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True" "You cannot pass text prompts when `skip_tokenizer_init=True`"
) )
return self.tokenizer return self.tokenizer
...@@ -228,22 +228,11 @@ class InputPreprocessor: ...@@ -228,22 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs) return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer()
return tokenizer
def _get_mm_processor(self) -> BaseMultiModalProcessor: def _get_mm_processor(self) -> BaseMultiModalProcessor:
if not hasattr(self, "_mm_processor"): if not hasattr(self, "_mm_processor"):
tokenizer = self._get_mm_tokenizer()
self._mm_processor = self.mm_registry.create_processor( self._mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
tokenizer=tokenizer, tokenizer=self.tokenizer,
cache=self.mm_processor_cache, cache=self.mm_processor_cache,
) )
......
...@@ -866,12 +866,6 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -866,12 +866,6 @@ class Glm4vVisionTransformer(nn.Module):
class Glm4vProcessingInfo(BaseProcessingInfo): class Glm4vProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_tokenizer(self):
return self.ctx.tokenizer
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "video": 1} return {"image": None, "video": 1}
......
...@@ -615,9 +615,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -615,9 +615,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
**kwargs, **kwargs,
) )
def get_tokenizer(self):
return self.ctx.tokenizer
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast: def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
......
...@@ -555,7 +555,7 @@ class QwenVLProcessor: ...@@ -555,7 +555,7 @@ class QwenVLProcessor:
class QwenVLProcessingInfo(BaseProcessingInfo): class QwenVLProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> PreTrainedTokenizer: def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = self.ctx.tokenizer tokenizer = self.ctx.get_tokenizer()
assert isinstance(tokenizer, PreTrainedTokenizer) assert isinstance(tokenizer, PreTrainedTokenizer)
return _get_tokenizer_without_image_pad(tokenizer) return _get_tokenizer_without_image_pad(tokenizer)
......
...@@ -97,15 +97,37 @@ def _cached_decode( ...@@ -97,15 +97,37 @@ def _cached_decode(
) )
def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str: def _seq2text(
tokenizer: TokenizerLike | None,
seq: PromptSeq,
*,
use_cache: bool = True,
) -> str:
if isinstance(seq, str): if isinstance(seq, str):
return seq return seq
if tokenizer is None:
raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`")
if not use_cache:
return decode_tokens(tokenizer, seq)
return _cached_decode(tokenizer, tuple(seq)) return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]: def _seq2tokens(
tokenizer: TokenizerLike | None,
seq: PromptSeq,
*,
use_cache: bool = True,
) -> list[int]:
if isinstance(seq, str): if isinstance(seq, str):
if tokenizer is None:
raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")
if not use_cache:
return encode_tokens(tokenizer, seq, add_special_tokens=False)
return _cached_encode(tokenizer, seq, add_special_tokens=False) return _cached_encode(tokenizer, seq, add_special_tokens=False)
return seq return seq
...@@ -114,7 +136,7 @@ def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]: ...@@ -114,7 +136,7 @@ def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol): class _GetMatchIndex(Protocol):
def __call__( def __call__(
self, self,
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ... ) -> int | None: ...
...@@ -144,7 +166,7 @@ class PromptIndexTargets: ...@@ -144,7 +166,7 @@ class PromptIndexTargets:
""" """
def get_match_index( def get_match_index(
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ) -> int | None:
...@@ -154,13 +176,11 @@ class PromptIndexTargets: ...@@ -154,13 +176,11 @@ class PromptIndexTargets:
prefix = seq prefix = seq
if isinstance(prompt, str): if isinstance(prompt, str):
if not isinstance(prefix, str):
# Make both `str` # Make both `str`
prefix = decode_tokens(tokenizer, prefix) prefix = _seq2text(tokenizer, prefix, use_cache=False)
else: else:
if isinstance(prefix, str):
# Make both `list[int]` # Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False) prefix = _seq2tokens(tokenizer, prefix, use_cache=False)
match_idx = len(prefix) match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None return match_idx if prompt[:match_idx] == prefix else None
...@@ -200,7 +220,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -200,7 +220,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S full: _S
"""The full content.""" """The full content."""
is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
""" """
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions return a boolean mask of shape `(len(full),)` indicating which positions
...@@ -221,8 +241,8 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -221,8 +241,8 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_text: str, embed_text: str,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text) embed_token_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
return torch.isin( return torch.isin(
...@@ -237,7 +257,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -237,7 +257,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_token_id: int, embed_token_id: int,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id return torch.tensor(token_ids) == embed_token_id
...@@ -523,7 +543,7 @@ class ResolvedPromptUpdate: ...@@ -523,7 +543,7 @@ class ResolvedPromptUpdate:
def iter_token_matches( def iter_token_matches(
self, self,
prompt: list[int], prompt: list[int],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -545,7 +565,7 @@ class ResolvedPromptUpdate: ...@@ -545,7 +565,7 @@ class ResolvedPromptUpdate:
def iter_text_matches( def iter_text_matches(
self, self,
prompt: str, prompt: str,
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -567,7 +587,7 @@ class ResolvedPromptUpdate: ...@@ -567,7 +587,7 @@ class ResolvedPromptUpdate:
def iter_matches( def iter_matches(
self, self,
prompt: list[int] | str, prompt: list[int] | str,
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -676,7 +696,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] ...@@ -676,7 +696,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches( def _find_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
*, *,
prev_end_idx: int = 0, prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult", current_result: "MultiModalPromptUpdatesApplyResult",
...@@ -741,7 +761,7 @@ def _all_items_found( ...@@ -741,7 +761,7 @@ def _all_items_found(
def _apply_matches( def _apply_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
...@@ -807,7 +827,7 @@ def _apply_matches( ...@@ -807,7 +827,7 @@ def _apply_matches(
def apply_token_matches( def apply_token_matches(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
...@@ -824,7 +844,7 @@ def apply_token_matches( ...@@ -824,7 +844,7 @@ def apply_token_matches(
def apply_text_matches( def apply_text_matches(
prompt: str, prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
...@@ -841,7 +861,7 @@ def apply_text_matches( ...@@ -841,7 +861,7 @@ def apply_text_matches(
def _iter_placeholders( def _iter_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
) -> Iterable[PlaceholderFeaturesInfo]: ) -> Iterable[PlaceholderFeaturesInfo]:
""" """
Yield each set of placeholder tokens found in `prompt`. Yield each set of placeholder tokens found in `prompt`.
...@@ -910,7 +930,7 @@ def _iter_placeholders( ...@@ -910,7 +930,7 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
...@@ -931,9 +951,17 @@ class InputProcessingContext: ...@@ -931,9 +951,17 @@ class InputProcessingContext:
model_config: ModelConfig model_config: ModelConfig
"""The configuration of the model.""" """The configuration of the model."""
tokenizer: TokenizerLike tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return self.tokenizer
@overload @overload
def get_hf_config(self, /) -> PretrainedConfig: ... def get_hf_config(self, /) -> PretrainedConfig: ...
...@@ -1148,7 +1176,7 @@ class BaseProcessingInfo: ...@@ -1148,7 +1176,7 @@ class BaseProcessingInfo:
return self.ctx.model_config.model return self.ctx.model_config.model
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer return self.ctx.get_tokenizer()
def get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
...@@ -1960,15 +1988,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1960,15 +1988,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for update_idxs in match_result.values() for update_idxs in match_result.values()
): ):
new_text, match_result = self._apply_text_matches( new_text, match_result = self._apply_text_matches(
decode_tokens(tokenizer, token_ids), _seq2text(tokenizer, token_ids, use_cache=False),
mm_prompt_updates, mm_prompt_updates,
) )
new_token_ids = encode_tokens( new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)
tokenizer,
new_text,
add_special_tokens=False,
)
matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list) matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, update_idxs in match_result.items(): for modality, update_idxs in match_result.items():
......
...@@ -234,9 +234,7 @@ class MultiModalRegistry: ...@@ -234,9 +234,7 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext: ) -> InputProcessingContext:
if model_config.skip_tokenizer_init: if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer) return InputProcessingContext(model_config, tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment