"vllm/vscode:/vscode.git/clone" did not exist on "1b7c791d60629453030de1600e756a8ba555455e"
Unverified Commit 34a98427 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
......@@ -27,8 +27,8 @@ ALLOWED_FILES = {
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py",
"tests/utils_/test_hashing.py",
"tests/tokenization/test_cached_tokenizer.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py",
......
......@@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
......
......@@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import PlaceholderModule
try:
......@@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
# Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
lora_tokenizer_cache: dict[int, TokenizerLike] = {}
def process_image(image: Any) -> Mapping[str, Any]:
......
......@@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
......@@ -85,7 +85,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(self) -> AnyTokenizer:
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...
......
......@@ -49,9 +49,9 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw
......@@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
......@@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
......@@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
super().__init__()
self._model_config = model_config
......@@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
......@@ -1624,7 +1624,7 @@ def parse_chat_messages(
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
......
......@@ -71,11 +71,8 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
MistralTokenizer,
get_cached_tokenizer,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
......@@ -350,11 +347,11 @@ class LLM:
self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
......@@ -1244,7 +1241,7 @@ class LLM:
def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None,
......@@ -1276,7 +1273,7 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None,
......
......@@ -62,8 +62,9 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
......@@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
......@@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time())
......@@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
self,
logprobs: dict[int, Logprob],
top_logprobs: int | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]:
return [
......@@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs:
......
......@@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......
......@@ -33,7 +33,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
......@@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
......@@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
......@@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
......@@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
......
......@@ -7,13 +7,14 @@ import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
......@@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
......@@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
@dataclass(kw_only=True)
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Sequence[RequestPrompt] | None = []
engine_prompts: list[EngineTokensPrompt] | None = []
model_config = ConfigDict(arbitrary_types_allowed=True)
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
class ResponseGenerationMixin(BaseModel):
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
......@@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(
RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
# Shared across all requests
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
# Shared across most requests
tokenizer: AnyTokenizer | None = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
tokenizer: TokenizerLike | None = None
ClassificationServeContext = ServeContext[ClassificationRequest]
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every request’s ID (e.g. "embd", "classify")
......@@ -281,7 +266,7 @@ class OpenAIServing:
apply_mistral_chat_template, executor=self._tokenizer_executor
)
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
......@@ -291,7 +276,7 @@ class OpenAIServing:
def _get_tool_parser(
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
) -> Callable[[AnyTokenizer], ToolParser] | None:
) -> Callable[[TokenizerLike], ToolParser] | None:
"""Get the tool parser based on the name."""
parser = None
if not enable_auto_tools or tool_parser_name is None:
......@@ -317,7 +302,7 @@ class OpenAIServing:
def _get_reasoning_parser(
self,
reasoning_parser_name: str,
) -> Callable[[AnyTokenizer], ReasoningParser] | None:
) -> Callable[[TokenizerLike], ReasoningParser] | None:
"""Get the reasoning parser based on the name."""
parser = None
if not reasoning_parser_name:
......@@ -547,7 +532,7 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
......@@ -877,7 +862,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
......@@ -919,7 +904,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
......@@ -1015,7 +1000,7 @@ class OpenAIServing:
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
......@@ -1034,7 +1019,7 @@ class OpenAIServing:
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
......@@ -1079,7 +1064,7 @@ class OpenAIServing:
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
......@@ -1088,13 +1073,18 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
......@@ -1370,9 +1360,9 @@ class OpenAIServing:
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]()
......@@ -1442,7 +1432,7 @@ class OpenAIServing:
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
return_as_token_id: bool = False,
) -> str:
if return_as_token_id:
......@@ -1450,6 +1440,12 @@ class OpenAIServing:
if logprob.decoded_token is not None:
return logprob.decoded_token
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: str | None) -> bool:
......
......@@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
......@@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
):
if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
......@@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> ErrorResponse | ResponsesResponse:
......@@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
logprobs: dict[int, SampleLogprob],
top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[LogprobTopLogprob]:
"""Returns the top-k logprobs from the logprobs dictionary."""
out = []
......@@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[Logprob]:
assert logprobs is not None, "logprobs must be provided"
......@@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs(
......@@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
final_output: CompletionOutput,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[ResponseOutputItem]:
if self.reasoning_parser:
try:
......@@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
......@@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
......@@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]:
......
......@@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)
......@@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
texts_1: list[str],
texts_2: list[str],
request: RerankRequest | ScoreRequest,
......@@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
def _preprocess_score(
self,
request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
......@@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,
......
......@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
@dataclass
class TokenizerInfo:
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
chat_template: str | None
def to_dict(self) -> dict[str, Any]:
......
......@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.sampling_params import (
StructuredOutputsParams,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path
......@@ -36,7 +36,7 @@ class ToolParser:
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
......
......@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
......
......@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
......
......@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
......
......@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
......
......@@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.bot_token = "<function_call>"
......
......@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment