Unverified commit 6c47f6bf, authored by Zhuohan Li, committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent c15309a7
...@@ -43,7 +43,7 @@ def _ref_convert_id_to_token( ...@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind, def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors): dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens) tokens_list=dummy_test_vectors.generation_tokens)
...@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int], num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], num_prompt_logprobs: Optional[int],
dummy_test_vectors): dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
...@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>' ) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops # Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids) suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
...@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool, def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors): num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
...@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def test_iteration_stats(dummy_test_vectors): def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True) log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic() engine_core_timestamp = time.monotonic()
......
...@@ -9,7 +9,6 @@ import torch ...@@ -9,7 +9,6 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
...@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors( ...@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors: class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests""" """Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType tokenizer: GeneralTokenizerType
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]] prompt_tokens: list[list[int]]
......
...@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices( ...@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
speculative_config=speculative_config, speculative_config=speculative_config,
) )
tokenizer = llm.get_tokenizer(None) tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
tokenizer=tokenizer) tokenizer=tokenizer)
......
...@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest ...@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
try: try:
...@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC): ...@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
def get_random_lora_request( def get_random_lora_request(
self, self,
tokenizer: PreTrainedTokenizerBase,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
) -> tuple[Optional[LoRARequest], AnyTokenizer]: ) -> Optional[LoRARequest]:
""" """
Optionally select a random LoRA request and return its associated Optionally select a random LoRA request.
tokenizer.
This method is used when LoRA parameters are provided. It randomly This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for selects a LoRA based on max_loras.
that LoRA if available. Otherwise, it returns the base tokenizer.
Args: Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available. max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used. If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk. lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used. If `None`, LoRA is not used.
Returns: Returns:
A tuple with the following elements: A new [LoRARequest][] (or `None` if not applicable).
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
""" """
if max_loras is None or lora_path is None: if max_loras is None or lora_path is None:
return None, tokenizer return None
# Generate a random LoRA ID in the range [1, max_loras]. # Generate a random LoRA ID in the range [1, max_loras].
lora_id = random.randint(1, max_loras) lora_id = random.randint(1, max_loras)
...@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC): ...@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
lora_int_id=lora_id, lora_int_id=lora_id,
lora_path=lora_path_on_disk(lora_path), lora_path=lora_path_on_disk(lora_path),
) )
if lora_id not in lora_tokenizer_cache: return lora_request
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(self, tokenizer: PreTrainedTokenizerBase,
...@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"], entry["conversations"][1]["value"],
) )
lora_request, tokenizer = self.get_random_lora_request( lora_request = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids) prompt_len = len(prompt_ids)
...@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for i in range(num_requests): for i in range(num_requests):
input_len = int(data[i][2]) input_len = int(data[i][2])
output_len = int(data[i][3]) output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request( lora_req = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i + # Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size. # j) modulo vocab_size.
...@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset): ...@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind), request_id=request_id_prefix + str(ind),
)) ))
ind += 1 ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
......
...@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self, async def get_tokenizer_async(self) -> AnyTokenizer:
lora_request: Optional[LoRARequest] = None return self.get_tokenizer()
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def add_request_async( async def add_request_async(
self, self,
...@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
...@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient): ...@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor: async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor return self.engine.input_preprocessor
async def get_tokenizer( async def get_tokenizer(self) -> AnyTokenizer:
self, return self.engine.get_tokenizer()
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
......
...@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, ...@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer_group import ( init_tokenizer_from_configs)
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
...@@ -186,7 +185,7 @@ class LLMEngine: ...@@ -186,7 +185,7 @@ class LLMEngine:
return outputs_ return outputs_
tokenizer: Optional[TokenizerGroup] tokenizer: Optional[AnyTokenizer]
def __init__( def __init__(
self, self,
...@@ -233,18 +232,9 @@ class LLMEngine: ...@@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
self.tokenizer = None self.tokenizer = None
self.detokenizer = None self.detokenizer = None
tokenizer_group = None
else: else:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = ( self.generation_config_fields = (
...@@ -389,10 +379,8 @@ class LLMEngine: ...@@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer, self.detokenizer,
self.scheduler, self.scheduler,
self.seq_counter, self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker( stop_checker=StopChecker(
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None, and self.tokenizer else None,
), ),
...@@ -521,24 +509,15 @@ class LLMEngine: ...@@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
return self.tokenizer return self.tokenizer
def get_tokenizer( def _init_tokenizer(self) -> AnyTokenizer:
self, return init_tokenizer_from_configs(model_config=self.model_config)
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -574,11 +553,11 @@ class LLMEngine: ...@@ -574,11 +553,11 @@ class LLMEngine:
) )
return None return None
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
...@@ -700,7 +679,6 @@ class LLMEngine: ...@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
) )
self._add_processed_request( self._add_processed_request(
...@@ -1739,29 +1717,22 @@ class LLMEngine: ...@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time) metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs, def _validate_model_inputs(self, inputs: ProcessorInputs):
lora_request: Optional[LoRARequest]):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None: if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, self._validate_model_input(encoder_inputs, prompt_type="encoder")
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs, self._validate_model_input(decoder_inputs, prompt_type="decoder")
lora_request,
prompt_type="decoder")
def _validate_model_input( def _validate_model_input(
self, self,
prompt_inputs: SingletonInputs, prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*, *,
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
): ):
model_config = self.model_config model_config = self.model_config
tokenizer = (None if self.tokenizer is None else tokenizer = self.tokenizer
self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs.get("prompt_token_ids", []) prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids: if not prompt_ids:
...@@ -1822,7 +1793,7 @@ class LLMEngine: ...@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = [] logits_processors = []
if (sampling_params.logit_bias or sampling_params.allowed_token_ids): if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request) tokenizer = self.get_tokenizer()
processors = get_openai_logits_processors( processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias, logit_bias=sampling_params.logit_bias,
...@@ -1835,7 +1806,7 @@ class LLMEngine: ...@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0: if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request) tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors( processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer) bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors) logits_processors.extend(processors)
......
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import List
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: List[Scheduler], scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
): ):
"""Create an output processor. """Create an output processor.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker: class StopChecker:
...@@ -20,12 +19,10 @@ class StopChecker: ...@@ -20,12 +19,10 @@ class StopChecker:
def __init__( def __init__(
self, self,
max_model_len: int, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None, reasoner: Optional[ReasoningParser] = None,
): ):
# Do not use it directly, but use `self._get_max_model_len`. # Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner self.reasoner = reasoner
def _get_max_model_len(self, lora_req: Optional[LoRARequest]): def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
......
...@@ -76,8 +76,7 @@ class EngineClient(ABC): ...@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor() preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group() tokenizer = preprocessor.get_tokenizer()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -260,11 +259,8 @@ class EngineClient(ABC): ...@@ -260,11 +259,8 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def get_tokenizer( async def get_tokenizer(self) -> AnyTokenizer:
self, """Get the tokenizer"""
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
... ...
async def get_io_processor(self) -> IOProcessor: async def get_io_processor(self) -> IOProcessor:
......
...@@ -301,23 +301,17 @@ class LLM: ...@@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config, self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin) io_processor_plugin)
def get_tokenizer( def get_tokenizer(self) -> AnyTokenizer:
self, return self.llm_engine.get_tokenizer()
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but # While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from # compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached' # user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"): if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer self.llm_engine.tokenizer = tokenizer
else: else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None: if self.default_sampling_params is None:
...@@ -707,7 +701,6 @@ class LLM: ...@@ -707,7 +701,6 @@ class LLM:
self, self,
messages: Union[list[ChatCompletionMessageParam], messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]], list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
...@@ -739,7 +732,7 @@ class LLM: ...@@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages) cast(list[ChatCompletionMessageParam], messages)
] ]
tokenizer = self.get_tokenizer(lora_request) tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config() model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
...@@ -872,7 +865,6 @@ class LLM: ...@@ -872,7 +865,6 @@ class LLM:
prompts = self.preprocess_chat( prompts = self.preprocess_chat(
messages=messages, messages=messages,
lora_request=lora_request,
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
......
...@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
tool_parser = self.tool_parser tool_parser = self.tool_parser
......
...@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None return None
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.tokenizer = await self.engine_client.get_tokenizer()
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
renderer = self._get_renderer(ctx.tokenizer) renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
......
...@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds( engine_prompts = await renderer.render_prompt_and_embeds(
......
...@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
......
...@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
......
...@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony: if self.use_harmony:
messages, request_prompts, engine_prompts = ( messages, request_prompts, engine_prompts = (
......
...@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing): ...@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None) None)
......
...@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
...@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
......
...@@ -9,13 +9,11 @@ from typing_extensions import assert_never ...@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict) MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType, EncoderDecoderInputs, ProcessorInputs, PromptType,
...@@ -31,7 +29,7 @@ class InputPreprocessor: ...@@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup], tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None: ) -> None:
...@@ -42,32 +40,28 @@ class InputPreprocessor: ...@@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when " raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True") "`skip_tokenizer_init` is True")
return self.tokenizer return self.tokenizer
def get_bos_token_id(self, def get_bos_token_id(self) -> Optional[int]:
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer " logger.warning("Using None for BOS token id because tokenizer "
"is not initialized") "is not initialized")
return None return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id return self.tokenizer.bos_token_id
def get_eos_token_id(self, def get_eos_token_id(self) -> Optional[int]:
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer " logger.warning("Using None for EOS token id because tokenizer "
"is not initialized") "is not initialized")
return None return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]: def get_decoder_start_token_id(self) -> Optional[int]:
""" """
...@@ -190,14 +184,13 @@ class InputPreprocessor: ...@@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]: ) -> list[int]:
""" """
Apply the model's tokenizer to a text prompt, returning the Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs. corresponding token IDs.
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config encoder_config = self.model_config.encoder_config
...@@ -205,50 +198,39 @@ class InputPreprocessor: ...@@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False): if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower() prompt = prompt.lower()
return tokenizer.encode(prompt=prompt, return tokenizer.encode(prompt, **tokenization_kwargs)
lora_request=lora_request,
**tokenization_kwargs)
async def _tokenize_prompt_async( async def _tokenize_prompt_async(
self, self,
prompt: str, prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]: ) -> list[int]:
""" """
Async version of Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return await tokenizer.encode_async(prompt=prompt, return tokenizer.encode(prompt, **tokenization_kwargs)
lora_request=lora_request,
**tokenization_kwargs)
def _get_mm_tokenizer( def _get_mm_tokenizer(self) -> AnyTokenizer:
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
return tokenizer_group.get_lora_tokenizer(lora_request) return tokenizer
async def _get_mm_tokenizer_async( async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
return await tokenizer_group.get_lora_tokenizer_async(lora_request) return tokenizer
def _process_multimodal( def _process_multimodal(
self, self,
...@@ -256,7 +238,6 @@ class InputPreprocessor: ...@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
...@@ -264,7 +245,7 @@ class InputPreprocessor: ...@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
tokenizer = self._get_mm_tokenizer(lora_request) tokenizer = self._get_mm_tokenizer()
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
...@@ -299,7 +280,6 @@ class InputPreprocessor: ...@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
...@@ -307,7 +287,7 @@ class InputPreprocessor: ...@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
""" """
tokenizer = await self._get_mm_tokenizer_async(lora_request) tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
...@@ -386,7 +366,6 @@ class InputPreprocessor: ...@@ -386,7 +366,6 @@ class InputPreprocessor:
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -400,7 +379,6 @@ class InputPreprocessor: ...@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
...@@ -415,7 +393,6 @@ class InputPreprocessor: ...@@ -415,7 +393,6 @@ class InputPreprocessor:
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -429,7 +406,6 @@ class InputPreprocessor: ...@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
...@@ -444,7 +420,6 @@ class InputPreprocessor: ...@@ -444,7 +420,6 @@ class InputPreprocessor:
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -457,13 +432,11 @@ class InputPreprocessor: ...@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
inputs = token_inputs( inputs = token_inputs(
...@@ -480,7 +453,6 @@ class InputPreprocessor: ...@@ -480,7 +453,6 @@ class InputPreprocessor:
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -493,13 +465,11 @@ class InputPreprocessor: ...@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
inputs = token_inputs( inputs = token_inputs(
...@@ -516,7 +486,6 @@ class InputPreprocessor: ...@@ -516,7 +486,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -526,7 +495,6 @@ class InputPreprocessor: ...@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns: Returns:
...@@ -539,21 +507,18 @@ class InputPreprocessor: ...@@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens": if parsed["type"] == "tokens":
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -563,7 +528,6 @@ class InputPreprocessor: ...@@ -563,7 +528,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -578,21 +542,18 @@ class InputPreprocessor: ...@@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens": if parsed["type"] == "tokens":
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -844,7 +805,6 @@ class InputPreprocessor: ...@@ -844,7 +805,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
...@@ -856,7 +816,6 @@ class InputPreprocessor: ...@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: input prompt * prompt: input prompt
* lora_request
Returns: Returns:
...@@ -866,7 +825,6 @@ class InputPreprocessor: ...@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -876,7 +834,6 @@ class InputPreprocessor: ...@@ -876,7 +834,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
...@@ -887,7 +844,6 @@ class InputPreprocessor: ...@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async( prompt_comps = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -897,7 +853,6 @@ class InputPreprocessor: ...@@ -897,7 +853,6 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
...@@ -919,7 +874,6 @@ class InputPreprocessor: ...@@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -927,7 +881,6 @@ class InputPreprocessor: ...@@ -927,7 +881,6 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
...@@ -952,7 +905,6 @@ class InputPreprocessor: ...@@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async( return await self._process_decoder_only_prompt_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
......
...@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, ...@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from .detokenizer_utils import (convert_prompt_ids_to_tokens, from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally) detokenize_incrementally)
from .tokenizer import AnyTokenizer from .tokenizer import AnyTokenizer
from .tokenizer_group import TokenizerGroup
class Detokenizer: class Detokenizer:
"""Provides methods to decode the output of a model into text.""" """Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: TokenizerGroup): def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer_group = tokenizer_group self.tokenizer = tokenizer
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[ prompt_logprobs: list[Optional[dict[
...@@ -46,7 +41,6 @@ class Detokenizer: ...@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token. # Only prompt, without the generated token.
all_token_ids = seq.get_token_ids() all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1] prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0 prefix_offset = 0
read_offset = 0 read_offset = 0
next_iter_prefix_offset = 0 next_iter_prefix_offset = 0
...@@ -70,7 +64,7 @@ class Detokenizer: ...@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids[:token_position] + [token_id]) prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset, (new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally( new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token, all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens, prev_tokens=prev_tokens,
prefix_offset=prefix_offset, prefix_offset=prefix_offset,
...@@ -111,7 +105,6 @@ class Detokenizer: ...@@ -111,7 +105,6 @@ class Detokenizer:
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary. # Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this # Do it here so that we don't have to repeat this
...@@ -119,14 +112,14 @@ class Detokenizer: ...@@ -119,14 +112,14 @@ class Detokenizer:
if seq.tokens is None: if seq.tokens is None:
(seq.tokens, seq.prefix_offset, (seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens( seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer, tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1], prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens, skip_special_tokens=prms.skip_special_tokens,
) )
(new_tokens, new_decoded_token_text, prefix_offset, (new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally( read_offset) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
...@@ -150,7 +143,7 @@ class Detokenizer: ...@@ -150,7 +143,7 @@ class Detokenizer:
and token_id != VLLM_INVALID_TOKEN_ID): and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id] all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally( (_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob, all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment