Unverified commit 6c47f6bf, authored by Zhuohan Li and committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent c15309a7
......@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens)
......@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
......@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
......
......@@ -9,7 +9,6 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
......@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
upper: float,
) -> torch.Tensor:
"""Create a random vector of top logprob float values.
Use to create fake sample logprobs for testing.
Note that a real production scenario would require
......@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
upper: float,
) -> torch.Tensor:
"""Create a random matrix of top logprob float values.
Use to create fake prompt logprobs for testing.
Note that a real production scenario would require
......@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]]
......
......@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser=reasoning_parser,
speculative_config=speculative_config,
)
tokenizer = llm.get_tokenizer(None)
tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
tokenizer=tokenizer)
......
This diff is collapsed.
......@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def get_tokenizer_async(self) -> AnyTokenizer:
return self.get_tokenizer()
async def add_request_async(
self,
......@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
......@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
async def get_tokenizer(self) -> AnyTokenizer:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
......
......@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
......@@ -186,7 +185,7 @@ class LLMEngine:
return outputs_
tokenizer: Optional[TokenizerGroup]
tokenizer: Optional[AnyTokenizer]
def __init__(
self,
......@@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
......@@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None,
),
......@@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _init_tokenizer(self) -> AnyTokenizer:
return init_tokenizer_from_configs(model_config=self.model_config)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -574,11 +553,11 @@ class LLMEngine:
)
return None
self._validate_model_inputs(processed_inputs, lora_request)
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
......@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)
self._add_processed_request(
......@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
tokenizer = self.tokenizer
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
......@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = []
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
tokenizer = self.get_tokenizer()
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
......@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)
......
......@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
......@@ -20,12 +19,10 @@ class StopChecker:
def __init__(
self,
max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None,
):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
......
......@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
tokenizer = preprocessor.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt):
......@@ -260,11 +259,8 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
async def get_tokenizer(self) -> AnyTokenizer:
"""Get the tokenizer"""
...
async def get_io_processor(self) -> IOProcessor:
......
......@@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer
self.llm_engine.tokenizer = tokenizer
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
......@@ -707,7 +701,6 @@ class LLM:
self,
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
......@@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
......@@ -872,7 +865,6 @@ class LLM:
prompts = self.preprocess_chat(
messages=messages,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt,
......@@ -1519,7 +1511,7 @@ class LLM:
):
"""
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
then its corresponding UUID must be set.
"""
if multi_modal_data is None:
return
......
......@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
tool_parser = self.tool_parser
......
......@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
ctx.tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
......
......@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
......
......@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
......@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results
with support for chunked processing.
For chunked requests, performs online aggregation to
For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
......
......@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
......
......@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony:
messages, request_prompts, engine_prompts = (
......
......@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
......
......@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
......@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id,
request.tokens,
......
......@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
......@@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
......@@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")
return self.tokenizer
def get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_bos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
return self.tokenizer.bos_token_id
def get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_eos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
"""
......@@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config
......@@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
async def _tokenize_prompt_async(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
def _process_multimodal(
self,
......@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer = self._get_mm_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -386,7 +366,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -415,7 +393,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -444,7 +420,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -480,7 +453,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -516,7 +486,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
......@@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -563,7 +528,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -844,7 +805,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* lora_request
Returns:
......@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -876,7 +834,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -897,7 +853,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -927,7 +881,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......
......@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import TokenizerGroup
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: TokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[
......@@ -32,9 +27,9 @@ class Detokenizer:
Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).
Returns:
The prompt logprobs with the decoded tokens.
"""
......@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
......@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens,
prefix_offset=prefix_offset,
......@@ -111,7 +105,6 @@ class Detokenizer:
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
......@@ -119,14 +112,14 @@ class Detokenizer:
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......@@ -150,7 +143,7 @@ class Detokenizer:
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment