Unverified Commit 6c47f6bf authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: default avatarZhuohan Li <zhuohan123@gmail.com>
parent c15309a7
...@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union ...@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from typing_extensions import assert_never
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -19,7 +20,6 @@ from vllm.transformers_utils.config import ( ...@@ -19,7 +20,6 @@ from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config) get_sentence_transformer_tokenizer_config)
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -274,20 +274,19 @@ def cached_tokenizer_from_config( ...@@ -274,20 +274,19 @@ def cached_tokenizer_from_config(
) )
def get_lora_tokenizer(lora_request: LoRARequest, *args, def init_tokenizer_from_configs(model_config: ModelConfig):
**kwargs) -> Optional[AnyTokenizer]: runner_type = model_config.runner_type
if lora_request is None: if runner_type == "generate" or runner_type == "draft":
return None truncation_side = "left"
try: elif runner_type == "pooling":
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs) truncation_side = "right"
except Exception as e: else:
# No tokenizer was found in the LoRA folder, assert_never(runner_type)
# use base model tokenizer
logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_path, e)
tokenizer = None
return tokenizer
get_lora_tokenizer_async = make_async(get_lora_tokenizer) return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
...@@ -61,6 +61,11 @@ class TokenizerBase(ABC): ...@@ -61,6 +61,11 @@ class TokenizerBase(ABC):
def max_token_id(self) -> int: def max_token_id(self) -> int:
raise NotImplementedError() raise NotImplementedError()
@property
@abstractmethod
def truncation_side(self) -> str:
raise NotImplementedError()
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing_extensions import assert_never
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.truncation_side = tokenizer_config.get("truncation_side", "left")
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: list[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)
def encode(self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
async def encode_async(
self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig]):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return TokenizerGroup(
tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side)
...@@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase): ...@@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self._max_token_id return self._max_token_id
@property
def truncation_side(self) -> str:
raise NotImplementedError()
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
......
...@@ -29,8 +29,8 @@ from vllm.tasks import SupportedTask ...@@ -29,8 +29,8 @@ from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
deprecate_kwargs) deprecate_kwargs)
...@@ -112,9 +112,7 @@ class AsyncLLM(EngineClient): ...@@ -112,9 +112,7 @@ class AsyncLLM(EngineClient):
else: else:
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config, model_config=vllm_config.model_config)
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests). # Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor( self.processor = Processor(
...@@ -596,15 +594,12 @@ class AsyncLLM(EngineClient): ...@@ -596,15 +594,12 @@ class AsyncLLM(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor: async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor return self.processor.input_preprocessor
async def get_tokenizer( async def get_tokenizer(self) -> AnyTokenizer:
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
return self.tokenizer.get_lora_tokenizer(lora_request) return self.tokenizer
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None return self.observability_config.otlp_traces_endpoint is not None
......
...@@ -20,8 +20,8 @@ from vllm.pooling_params import PoolingParams ...@@ -20,8 +20,8 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer import (AnyTokenizer,
TokenizerGroup, init_tokenizer_from_configs) init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
...@@ -89,9 +89,7 @@ class LLMEngine: ...@@ -89,9 +89,7 @@ class LLMEngine:
else: else:
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config, model_config=vllm_config.model_config)
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config, self.processor = Processor(vllm_config=vllm_config,
...@@ -297,7 +295,7 @@ class LLMEngine: ...@@ -297,7 +295,7 @@ class LLMEngine:
assert self.log_stats, "Stat logging disabled" assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot() return get_metrics_snapshot()
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
......
...@@ -14,7 +14,6 @@ from vllm.sampling_params import RequestOutputKind ...@@ -14,7 +14,6 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer, from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context) extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.logprobs import LogprobsProcessor
...@@ -290,7 +289,7 @@ class RequestState: ...@@ -290,7 +289,7 @@ class RequestState:
class OutputProcessor: class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs.""" """Process EngineCoreOutputs into RequestOutputs."""
def __init__(self, tokenizer: TokenizerGroup, log_stats: bool): def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
self.log_stats = log_stats self.log_stats = log_stats
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {} self.request_states: dict[str, RequestState] = {}
...@@ -347,10 +346,7 @@ class OutputProcessor: ...@@ -347,10 +346,7 @@ class OutputProcessor:
if request_id in self.request_states: if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.") raise ValueError(f"Request id {request_id} already running.")
tokenizer = None if not self.tokenizer else \ req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
self.tokenizer.get_lora_tokenizer(request.lora_request)
req_state = RequestState.from_new_request(tokenizer=tokenizer,
request=request, request=request,
prompt=prompt, prompt=prompt,
parent_req=parent_req, parent_req=parent_req,
......
...@@ -9,6 +9,7 @@ from vllm.config import VllmConfig ...@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
...@@ -17,7 +18,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor ...@@ -17,7 +18,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar) validate_guidance_grammar)
...@@ -28,13 +29,15 @@ from vllm.v1.structured_output.backend_outlines import ( ...@@ -28,13 +29,15 @@ from vllm.v1.structured_output.backend_outlines import (
from vllm.v1.structured_output.backend_xgrammar import ( from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar) validate_xgrammar_grammar)
logger = init_logger(__name__)
class Processor: class Processor:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
tokenizer: TokenizerGroup, tokenizer: AnyTokenizer,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
...@@ -90,7 +93,6 @@ class Processor: ...@@ -90,7 +93,6 @@ class Processor:
def _validate_sampling_params( def _validate_sampling_params(
self, self,
params: SamplingParams, params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None: ) -> None:
self._validate_structured_output(params) self._validate_structured_output(params)
self._validate_logit_bias(params) self._validate_logit_bias(params)
...@@ -103,8 +105,7 @@ class Processor: ...@@ -103,8 +105,7 @@ class Processor:
# When skip_tokenizer_init=True, we can't validate token IDs # When skip_tokenizer_init=True, we can't validate token IDs
# Skip validation and let the model handle invalid tokens # Skip validation and let the model handle invalid tokens
return return
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) vocab_size = len(self.tokenizer)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError( raise ValueError(
"allowed_token_ids contains out-of-vocab token id!") "allowed_token_ids contains out-of-vocab token id!")
...@@ -144,7 +145,6 @@ class Processor: ...@@ -144,7 +145,6 @@ class Processor:
def _validate_params( def _validate_params(
self, self,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest],
): ):
""" """
Validate supported SamplingParam. Validate supported SamplingParam.
...@@ -155,14 +155,14 @@ class Processor: ...@@ -155,14 +155,14 @@ class Processor:
return return
self._validate_logprobs(params) self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request) self._validate_sampling_params(params)
self._validate_supported_sampling_params(params) self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
""" """
Validate that user-provided multi_modal_uuids align with Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s). multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream. auto-hashed downstream.
""" """
...@@ -202,10 +202,22 @@ class Processor: ...@@ -202,10 +202,22 @@ class Processor:
_validate_single_prompt(prompt) # type: ignore[arg-type] _validate_single_prompt(prompt) # type: ignore[arg-type]
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config: if lora_request is None:
return
# LoRA request passed in while LoRA is not enabled
if not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if self.tokenizer is not None:
logger.warning_once(
"vLLM has deprecated support for supporting different "
"tokenizers for different LoRAs. By default, vLLM uses base "
"model's tokenizer. If you are using a LoRA "
"with its own tokenizer, consider specifying `--tokenizer "
"[lora_path]` to use the LoRA tokenizer.")
def _validate_structured_output(self, params: SamplingParams) -> None: def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config: if not params.guided_decoding or not self.decoding_config:
return return
...@@ -326,7 +338,7 @@ class Processor: ...@@ -326,7 +338,7 @@ class Processor:
# TODO(woosuk): Support pooling models. # TODO(woosuk): Support pooling models.
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params, lora_request) self._validate_params(params)
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank < if data_parallel_rank is not None and not (0 <= data_parallel_rank <
...@@ -365,7 +377,6 @@ class Processor: ...@@ -365,7 +377,6 @@ class Processor:
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -375,9 +386,9 @@ class Processor: ...@@ -375,9 +386,9 @@ class Processor:
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
) )
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id()
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
...@@ -394,8 +405,7 @@ class Processor: ...@@ -394,8 +405,7 @@ class Processor:
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
if self.tokenizer is not None: if self.tokenizer is not None:
sampling_params.update_from_tokenizer( sampling_params.update_from_tokenizer(self.tokenizer)
self.tokenizer.get_lora_tokenizer(lora_request))
else: else:
pooling_params = params.clone() pooling_params = params.clone()
...@@ -436,24 +446,17 @@ class Processor: ...@@ -436,24 +446,17 @@ class Processor:
trace_headers=trace_headers, trace_headers=trace_headers,
) )
def _validate_model_inputs(self, def _validate_model_inputs(self, inputs: ProcessorInputs):
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None: if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, self._validate_model_input(encoder_inputs, prompt_type="encoder")
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs, self._validate_model_input(decoder_inputs, prompt_type="decoder")
lora_request,
prompt_type="decoder")
def _validate_model_input( def _validate_model_input(
self, self,
prompt_inputs: SingletonInputs, prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*, *,
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
): ):
...@@ -469,7 +472,7 @@ class Processor: ...@@ -469,7 +472,7 @@ class Processor:
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) tokenizer = self.tokenizer
max_input_id = max(prompt_ids, default=0) max_input_id = max(prompt_ids, default=0)
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
......
...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.utils import LazyLoader from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
...@@ -60,10 +60,7 @@ class StructuredOutputManager: ...@@ -60,10 +60,7 @@ class StructuredOutputManager:
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config, model_config=self.vllm_config.model_config)
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
reasoning_backend = \ reasoning_backend = \
self.vllm_config.decoding_config.reasoning_backend self.vllm_config.decoding_config.reasoning_backend
if reasoning_backend: if reasoning_backend:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment