"vllm/model_executor/models/colqwen3.py" did not exist on "de42abb366032519bca073e057331ead6270e09f"
Unverified Commit 6c47f6bf authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: default avatarZhuohan Li <zhuohan123@gmail.com>
parent c15309a7
......@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens)
......@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
......@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
......
......@@ -9,7 +9,6 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
......@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]]
......
......@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser=reasoning_parser,
speculative_config=speculative_config,
)
tokenizer = llm.get_tokenizer(None)
tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
tokenizer=tokenizer)
......
......@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import PlaceholderModule
try:
......@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
def get_random_lora_request(
self,
tokenizer: PreTrainedTokenizerBase,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
) -> Optional[LoRARequest]:
"""
Optionally select a random LoRA request and return its associated
tokenizer.
Optionally select a random LoRA request.
This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for
that LoRA if available. Otherwise, it returns the base tokenizer.
selects a LoRA based on max_loras.
Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used.
Returns:
A tuple with the following elements:
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
A new [LoRARequest][] (or `None` if not applicable).
"""
if max_loras is None or lora_path is None:
return None, tokenizer
return None
# Generate a random LoRA ID in the range [1, max_loras].
lora_id = random.randint(1, max_loras)
......@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
lora_int_id=lora_id,
lora_path=lora_path_on_disk(lora_path),
)
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
return lora_request
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
......@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"],
)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_request = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
......@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for i in range(num_requests):
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_req = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
......@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len=output_len,
multi_modal_data=None,
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
......
......@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def get_tokenizer_async(self) -> AnyTokenizer:
return self.get_tokenizer()
async def add_request_async(
self,
......@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
......@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
async def get_tokenizer(self) -> AnyTokenizer:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
......
......@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
......@@ -186,7 +185,7 @@ class LLMEngine:
return outputs_
tokenizer: Optional[TokenizerGroup]
tokenizer: Optional[AnyTokenizer]
def __init__(
self,
......@@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
......@@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None,
),
......@@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _init_tokenizer(self) -> AnyTokenizer:
return init_tokenizer_from_configs(model_config=self.model_config)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -574,11 +553,11 @@ class LLMEngine:
)
return None
self._validate_model_inputs(processed_inputs, lora_request)
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
......@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)
self._add_processed_request(
......@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
tokenizer = self.tokenizer
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
......@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = []
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
tokenizer = self.get_tokenizer()
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
......@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)
......
......@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
......@@ -20,12 +19,10 @@ class StopChecker:
def __init__(
self,
max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None,
):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
......
......@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
tokenizer = preprocessor.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt):
......@@ -260,11 +259,8 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
async def get_tokenizer(self) -> AnyTokenizer:
"""Get the tokenizer"""
...
async def get_io_processor(self) -> IOProcessor:
......
......@@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer
self.llm_engine.tokenizer = tokenizer
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
......@@ -707,7 +701,6 @@ class LLM:
self,
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
......@@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
......@@ -872,7 +865,6 @@ class LLM:
prompts = self.preprocess_chat(
messages=messages,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt,
......
......@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
tool_parser = self.tool_parser
......
......@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
ctx.tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
......
......@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
......
......@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
......
......@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
......
......@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony:
messages, request_prompts, engine_prompts = (
......
......@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
......
......@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
......@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id,
request.tokens,
......
......@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
......@@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
......@@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")
return self.tokenizer
def get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_bos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
return self.tokenizer.bos_token_id
def get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_eos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
"""
......@@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config
......@@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
async def _tokenize_prompt_async(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
def _process_multimodal(
self,
......@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer = self._get_mm_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -386,7 +366,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -415,7 +393,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -444,7 +420,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -480,7 +453,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -516,7 +486,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
......@@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -563,7 +528,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -844,7 +805,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* lora_request
Returns:
......@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -876,7 +834,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -897,7 +853,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -927,7 +881,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......
......@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import TokenizerGroup
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: TokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[
......@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
......@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens,
prefix_offset=prefix_offset,
......@@ -111,7 +105,6 @@ class Detokenizer:
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
......@@ -119,14 +112,14 @@ class Detokenizer:
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......@@ -150,7 +143,7 @@ class Detokenizer:
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment