Commit 34a98427 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
...@@ -27,8 +27,8 @@ ALLOWED_FILES = { ...@@ -27,8 +27,8 @@ ALLOWED_FILES = {
"vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py", "vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py", "vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py",
"tests/utils_/test_hashing.py", "tests/utils_/test_hashing.py",
"tests/tokenization/test_cached_tokenizer.py",
"benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py", "benchmarks/kernels/benchmark_machete.py",
......
...@@ -35,6 +35,7 @@ FILES = [ ...@@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal", "vllm/multimodal",
"vllm/platforms", "vllm/platforms",
"vllm/plugins", "vllm/plugins",
"vllm/tokenizers",
"vllm/transformers_utils", "vllm/transformers_utils",
"vllm/triton_utils", "vllm/triton_utils",
"vllm/usage", "vllm/usage",
......
...@@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest ...@@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
try: try:
...@@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str: ...@@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
# Global cache for LoRA tokenizers. # Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {} lora_tokenizer_cache: dict[int, TokenizerLike] = {}
def process_image(image: Any) -> Mapping[str, Any]: def process_image(image: Any) -> Mapping[str, Any]:
......
...@@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor ...@@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
...@@ -85,7 +85,7 @@ class EngineClient(ABC): ...@@ -85,7 +85,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def get_tokenizer(self) -> AnyTokenizer: async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer""" """Get the tokenizer"""
... ...
......
...@@ -49,9 +49,9 @@ from vllm.logger import init_logger ...@@ -49,9 +49,9 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
...@@ -536,7 +536,7 @@ def resolve_hf_chat_template( ...@@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format( def _resolve_chat_template_content_format(
chat_template: str | None, chat_template: str | None,
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
model_config: ModelConfig, model_config: ModelConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
...@@ -593,7 +593,7 @@ def resolve_chat_template_content_format( ...@@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None, chat_template: str | None,
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
model_config: ModelConfig, model_config: ModelConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
...@@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt. maximum per prompt.
""" """
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
super().__init__() super().__init__()
self._model_config = model_config self._model_config = model_config
...@@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: ...@@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages( def parse_chat_messages(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
...@@ -1624,7 +1624,7 @@ def parse_chat_messages( ...@@ -1624,7 +1624,7 @@ def parse_chat_messages(
def parse_chat_messages_futures( def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
......
...@@ -71,11 +71,8 @@ from vllm.platforms import current_platform ...@@ -71,11 +71,8 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer, TokenizerLike
AnyTokenizer, from vllm.transformers_utils.tokenizer import get_cached_tokenizer
MistralTokenizer,
get_cached_tokenizer,
)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
...@@ -350,11 +347,11 @@ class LLM: ...@@ -350,11 +347,11 @@ class LLM:
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.") @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but # While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from # compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached' # user-defined tokenizer started with 'Cached'
...@@ -1244,7 +1241,7 @@ class LLM: ...@@ -1244,7 +1241,7 @@ class LLM:
def _embedding_score( def _embedding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt], text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt], text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
...@@ -1276,7 +1273,7 @@ class LLM: ...@@ -1276,7 +1273,7 @@ class LLM:
def _cross_encoding_score( def _cross_encoding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam], data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
......
...@@ -62,8 +62,9 @@ from vllm.logger import init_logger ...@@ -62,8 +62,9 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizers import ( from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls, maybe_serialize_tool_calls,
truncate_tool_call_ids, truncate_tool_call_ids,
validate_request_params, validate_request_params,
...@@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
...@@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time()) created_time = int(time.time())
...@@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
logprobs: dict[int, Logprob], logprobs: dict[int, Logprob],
top_logprobs: int | None, top_logprobs: int | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
should_return_as_token_id: bool, should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]: ) -> list[ChatCompletionLogProb]:
return [ return [
...@@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None], top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
num_output_top_logprobs: int | None = None, num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None, return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:
......
...@@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin): ...@@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -33,7 +33,7 @@ from vllm.logger import init_logger ...@@ -33,7 +33,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
...@@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
...@@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> CompletionResponse: ) -> CompletionResponse:
choices: list[CompletionResponseChoice] = [] choices: list[CompletionResponseChoice] = []
...@@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None], top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int, num_output_top_logprobs: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
initial_text_offset: int = 0, initial_text_offset: int = 0,
return_as_token_id: bool | None = None, return_as_token_id: bool | None = None,
) -> CompletionLogProbs: ) -> CompletionLogProbs:
...@@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None: if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id: if should_return_as_token_id:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)
......
...@@ -7,13 +7,14 @@ import time ...@@ -7,13 +7,14 @@ import time
import traceback import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch import torch
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
...@@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput ...@@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tracing import ( from vllm.tracing import (
contains_trace_headers, contains_trace_headers,
extract_trace_headers, extract_trace_headers,
log_tracing_disabled_warning, log_tracing_disabled_warning,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import ( from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer, AsyncMicrobatchTokenizer,
...@@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: ...@@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
RequestT = TypeVar("RequestT", bound=AnyRequest) RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel): @dataclass(kw_only=True)
class RequestProcessingMixin:
""" """
Mixin for request processing, Mixin for request processing,
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Sequence[RequestPrompt] | None = [] request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = [] engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ResponseGenerationMixin(BaseModel): @dataclass(kw_only=True)
class ResponseGenerationMixin:
""" """
Mixin for response generation, Mixin for response generation,
managing result generators and final batch results. managing result generators and final batch results.
...@@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel): ...@@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
result_generator: ( result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None ) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field( final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list default_factory=list
) )
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext( @dataclass(kw_only=True)
RequestProcessingMixin, class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
# Shared across all requests # Shared across all requests
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = Field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
# Shared across most requests # Shared across most requests
tokenizer: AnyTokenizer | None = None tokenizer: TokenizerLike | None = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
ClassificationServeContext = ServeContext[ClassificationRequest] @dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]): class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing: class OpenAIServing:
request_id_prefix: ClassVar[str] = """ request_id_prefix: ClassVar[str] = """
A short string prepended to every request’s ID (e.g. "embd", "classify") A short string prepended to every request’s ID (e.g. "embd", "classify")
...@@ -281,7 +266,7 @@ class OpenAIServing: ...@@ -281,7 +266,7 @@ class OpenAIServing:
apply_mistral_chat_template, executor=self._tokenizer_executor apply_mistral_chat_template, executor=self._tokenizer_executor
) )
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor self.input_processor = self.models.input_processor
...@@ -291,7 +276,7 @@ class OpenAIServing: ...@@ -291,7 +276,7 @@ class OpenAIServing:
def _get_tool_parser( def _get_tool_parser(
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
) -> Callable[[AnyTokenizer], ToolParser] | None: ) -> Callable[[TokenizerLike], ToolParser] | None:
"""Get the tool parser based on the name.""" """Get the tool parser based on the name."""
parser = None parser = None
if not enable_auto_tools or tool_parser_name is None: if not enable_auto_tools or tool_parser_name is None:
...@@ -317,7 +302,7 @@ class OpenAIServing: ...@@ -317,7 +302,7 @@ class OpenAIServing:
def _get_reasoning_parser( def _get_reasoning_parser(
self, self,
reasoning_parser_name: str, reasoning_parser_name: str,
) -> Callable[[AnyTokenizer], ReasoningParser] | None: ) -> Callable[[TokenizerLike], ReasoningParser] | None:
"""Get the reasoning parser based on the name.""" """Get the reasoning parser based on the name."""
parser = None parser = None
if not reasoning_parser_name: if not reasoning_parser_name:
...@@ -547,7 +532,7 @@ class OpenAIServing: ...@@ -547,7 +532,7 @@ class OpenAIServing:
prompt_logprobs=None, prompt_logprobs=None,
) )
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer: def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
""" """
Get a Renderer instance with the provided tokenizer. Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency. Uses shared async tokenizer pool for efficiency.
...@@ -877,7 +862,7 @@ class OpenAIServing: ...@@ -877,7 +862,7 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
prompt: str, prompt: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
add_special_tokens: bool, add_special_tokens: bool,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer) async_tokenizer = self._get_async_tokenizer(tokenizer)
...@@ -919,7 +904,7 @@ class OpenAIServing: ...@@ -919,7 +904,7 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
prompt_ids: list[int], prompt_ids: list[int],
tokenizer: AnyTokenizer | None, tokenizer: TokenizerLike | None,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
...@@ -1015,7 +1000,7 @@ class OpenAIServing: ...@@ -1015,7 +1000,7 @@ class OpenAIServing:
async def _tokenize_prompt_input_async( async def _tokenize_prompt_input_async(
self, self,
request: AnyRequest, request: AnyRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt_input: str | list[int], prompt_input: str | list[int],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
...@@ -1034,7 +1019,7 @@ class OpenAIServing: ...@@ -1034,7 +1019,7 @@ class OpenAIServing:
async def _tokenize_prompt_inputs_async( async def _tokenize_prompt_inputs_async(
self, self,
request: AnyRequest, request: AnyRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]], prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]: ) -> AsyncGenerator[TextTokensPrompt, None]:
...@@ -1079,7 +1064,7 @@ class OpenAIServing: ...@@ -1079,7 +1064,7 @@ class OpenAIServing:
async def _preprocess_chat( async def _preprocess_chat(
self, self,
request: ChatLikeRequest | ResponsesRequest, request: ChatLikeRequest | ResponsesRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
...@@ -1088,13 +1073,18 @@ class OpenAIServing: ...@@ -1088,13 +1073,18 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None, documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
Sequence[RequestPrompt], Sequence[RequestPrompt],
list[EngineTokensPrompt], list[EngineTokensPrompt],
]: ]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
...@@ -1370,9 +1360,9 @@ class OpenAIServing: ...@@ -1370,9 +1360,9 @@ class OpenAIServing:
@staticmethod @staticmethod
def _parse_tool_calls_from_content( def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest, request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
enable_auto_tools: bool, enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None, content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]: ) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]() function_calls = list[FunctionCall]()
...@@ -1442,7 +1432,7 @@ class OpenAIServing: ...@@ -1442,7 +1432,7 @@ class OpenAIServing:
def _get_decoded_token( def _get_decoded_token(
logprob: Logprob, logprob: Logprob,
token_id: int, token_id: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
return_as_token_id: bool = False, return_as_token_id: bool = False,
) -> str: ) -> str:
if return_as_token_id: if return_as_token_id:
...@@ -1450,6 +1440,12 @@ class OpenAIServing: ...@@ -1450,6 +1440,12 @@ class OpenAIServing:
if logprob.decoded_token is not None: if logprob.decoded_token is not None:
return logprob.decoded_token return logprob.decoded_token
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id) return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: str | None) -> bool: def _is_model_supported(self, model_name: str | None) -> bool:
......
...@@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob ...@@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
prev_response: ResponsesResponse | None, prev_response: ResponsesResponse | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
): ):
if request.tools is None or ( if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
...@@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext], result_generator: AsyncIterator[ConversationContext],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int | None = None, created_time: int | None = None,
) -> ErrorResponse | ResponsesResponse: ) -> ErrorResponse | ResponsesResponse:
...@@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
logprobs: dict[int, SampleLogprob], logprobs: dict[int, SampleLogprob],
top_logprobs: int, top_logprobs: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> list[LogprobTopLogprob]: ) -> list[LogprobTopLogprob]:
"""Returns the top-k logprobs from the logprobs dictionary.""" """Returns the top-k logprobs from the logprobs dictionary."""
out = [] out = []
...@@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
token_ids: Sequence[int], token_ids: Sequence[int],
logprobs: SampleLogprobs | None, logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
top_logprobs: int | None = None, top_logprobs: int | None = None,
) -> list[Logprob]: ) -> list[Logprob]:
assert logprobs is not None, "logprobs must be provided" assert logprobs is not None, "logprobs must be provided"
...@@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
token_ids: Sequence[int], token_ids: Sequence[int],
logprobs: SampleLogprobs | None, logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
top_logprobs: int | None = None, top_logprobs: int | None = None,
) -> list[response_text_delta_event.Logprob]: ) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs( lgs = self._create_response_logprobs(
...@@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
final_output: CompletionOutput, final_output: CompletionOutput,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> list[ResponseOutputItem]: ) -> list[ResponseOutputItem]:
if self.reasoning_parser: if self.reasoning_parser:
try: try:
...@@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int, created_time: int,
_increment_sequence_number_and_return: Callable[ _increment_sequence_number_and_return: Callable[
...@@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int, created_time: int,
_increment_sequence_number_and_return: Callable[ _increment_sequence_number_and_return: Callable[
...@@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int | None = None, created_time: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]: ) -> AsyncGenerator[StreamingResponsesResponse, None]:
......
...@@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt ...@@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -60,7 +60,7 @@ class ServingScores(OpenAIServing): ...@@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score( async def _embedding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
texts_1: list[str], texts_1: list[str],
texts_2: list[str], texts_2: list[str],
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
...@@ -153,7 +153,7 @@ class ServingScores(OpenAIServing): ...@@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
def _preprocess_score( def _preprocess_score(
self, self,
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
...@@ -175,7 +175,7 @@ class ServingScores(OpenAIServing): ...@@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score( async def _cross_encoding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam], data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
......
...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing ...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
@dataclass @dataclass
class TokenizerInfo: class TokenizerInfo:
tokenizer: AnyTokenizer tokenizer: TokenizerLike
chat_template: str | None chat_template: str | None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
......
...@@ -22,7 +22,7 @@ from vllm.logger import init_logger ...@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.sampling_params import ( from vllm.sampling_params import (
StructuredOutputsParams, StructuredOutputsParams,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path from vllm.utils.import_utils import import_from_path
...@@ -36,7 +36,7 @@ class ToolParser: ...@@ -36,7 +36,7 @@ class ToolParser:
derived classes. derived classes.
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed # the index of the tool call that is currently being parsed
self.current_tool_id: int = -1 self.current_tool_id: int = -1
......
...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser): class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
......
...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser): class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
......
...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser): class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
""" """
Ernie thinking model format: Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
......
...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser): class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent = False self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( ...@@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads, partial_json_loads,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser): ...@@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.bot_token = "<function_call>" self.bot_token = "<function_call>"
......
...@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( ...@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads, partial_json_loads,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser): ...@@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>` # for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>" self.bot_token = "<|tool_call|>"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment