Unverified commit 6c47f6bf, authored by Zhuohan Li, committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent c15309a7
......@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
......@@ -19,7 +20,6 @@ from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -274,20 +274,19 @@ def cached_tokenizer_from_config(
)
def get_lora_tokenizer(lora_request: LoRARequest, *args,
                       **kwargs) -> Optional[AnyTokenizer]:
    """Try to load the tokenizer that ships with a LoRA adapter.

    Args:
        lora_request: Request whose ``lora_path`` is searched for a
            tokenizer. ``None`` is tolerated and yields ``None``.
        *args: Positional arguments forwarded to ``get_tokenizer``.
        **kwargs: Keyword arguments forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_path``, or ``None`` if
        no request was given or loading raised any exception.
    """
    if lora_request is None:
        return None

    try:
        return get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except Exception as e:
        # Best-effort on purpose: a LoRA folder without a tokenizer is not an
        # error. The failure is logged and ``None`` is returned so the caller
        # can fall back to the base model tokenizer.
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        return None
# Build a tokenizer from a ModelConfig: picks a truncation side from
# model_config.runner_type and forwards it, together with the tokenizer
# name, mode, trust_remote_code flag and revision, to get_tokenizer.
def init_tokenizer_from_configs(model_config: ModelConfig):
runner_type = model_config.runner_type
# "generate" and "draft" runners use truncation_side="left";
# "pooling" runners use truncation_side="right".
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
# Any other runner_type value is handed to assert_never.
assert_never(runner_type)
# NOTE(review): the next line looks like a module-level alias from the
# pre-change version of this file that the diff view interleaved here; it
# is presumably not part of this function's body -- confirm against the
# actual file before relying on it.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
......@@ -61,6 +61,11 @@ class TokenizerBase(ABC):
def max_token_id(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def truncation_side(self) -> str:
    """Truncation side of this tokenizer; concrete subclasses supply it."""
    raise NotImplementedError
def __len__(self) -> int:
return self.vocab_size
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing_extensions import assert_never
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache
class TokenizerGroup:
    """A base tokenizer plus a cache of tokenizers for LoRA adapters.

    The base tokenizer is loaded once at construction time. Tokenizers for
    individual LoRA requests are loaded on first use and kept in an LRU cache
    keyed by ``lora_request.lora_int_id``.
    """

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Load the base tokenizer and set up the LoRA tokenizer cache.

        Args:
            tokenizer_id: Identifier passed to ``get_tokenizer``.
            enable_lora: Whether per-LoRA tokenizers may be used at all.
            max_num_seqs: Lower bound for the cache capacity when LoRA is
                enabled.
            max_input_length: Maximum number of tokens an encoded prompt may
                have, or ``None`` for no limit.
            **tokenizer_config: Forwarded unchanged to ``get_tokenizer`` (and
                to ``get_lora_tokenizer``); ``truncation_side`` and
                ``max_loras`` are additionally read from it here.
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.truncation_side = tokenizer_config.get("truncation_side", "left")
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)

        # With LoRA disabled the cache is created with zero capacity.
        max_loras = tokenizer_config.get("max_loras", 0)
        if enable_lora:
            cache_capacity = max(max_loras, max_num_seqs)
        else:
            cache_capacity = 0
        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
            capacity=cache_capacity)

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Return the configured maximum input length.

        ``lora_request`` is accepted but not consulted.
        """
        return self.max_input_length

    def _raise_if_input_too_long(self,
                                 encoded_tokens: list[int],
                                 lora_request: Optional[LoRARequest] = None):
        """Raise ``ValueError`` if ``encoded_tokens`` exceeds the limit.

        For a LoRA request, ``lora_request.long_lora_max_len`` takes
        precedence when it is truthy; otherwise ``self.max_input_length``
        applies. A limit of ``None`` disables the check.
        """
        num_tokens = len(encoded_tokens)

        limit = self.max_input_length
        if lora_request:
            limit = lora_request.long_lora_max_len or limit

        if limit is None:
            return
        if num_tokens > limit:
            raise ValueError("Input too long.", num_tokens, limit)

    def encode(self,
               prompt: str,
               max_length: Optional[int] = None,
               truncation: Optional[bool] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> list[int]:
        """Tokenize ``prompt`` with the tokenizer selected for the request.

        Returns:
            The token ids produced by ``encode_tokens``.

        Raises:
            ValueError: If the result is longer than the applicable limit.
        """
        token_ids = encode_tokens(self.get_lora_tokenizer(lora_request),
                                  prompt,
                                  max_length=max_length,
                                  truncation=truncation,
                                  add_special_tokens=add_special_tokens)
        self._raise_if_input_too_long(token_ids, lora_request)
        return token_ids

    async def encode_async(
            self,
            prompt: str,
            max_length: Optional[int] = None,
            truncation: Optional[bool] = None,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> list[int]:
        """Async counterpart of ``encode``.

        Only the tokenizer lookup is awaited; ``encode_tokens`` itself is
        called synchronously.

        Raises:
            ValueError: If the result is longer than the applicable limit.
        """
        selected = await self.get_lora_tokenizer_async(lora_request)
        token_ids = encode_tokens(selected,
                                  prompt,
                                  max_length=max_length,
                                  truncation=truncation,
                                  add_special_tokens=add_special_tokens)
        self._raise_if_input_too_long(token_ids, lora_request)
        return token_ids

    def get_lora_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use for ``lora_request``.

        The base tokenizer is returned when there is no request or LoRA is
        disabled. Otherwise the cached tokenizer for the request's
        ``lora_int_id`` is returned, loading and caching it first if needed;
        the base tokenizer is cached instead when loading yields nothing.
        """
        if not (lora_request and self.enable_lora):
            return self.tokenizer

        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]

        # The bare name resolves to the module-level loader, not this method.
        loaded = get_lora_tokenizer(lora_request, **self.tokenizer_config)
        chosen = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, chosen)
        return chosen

    async def get_lora_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Async counterpart of ``get_lora_tokenizer``.

        Identical selection and caching rules; a cache miss awaits the
        module-level ``get_lora_tokenizer_async`` loader.
        """
        if not (lora_request and self.enable_lora):
            return self.tokenizer

        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers[lora_id]

        # The bare name resolves to the module-level loader, not this method.
        loaded = await get_lora_tokenizer_async(lora_request,
                                                **self.tokenizer_config)
        chosen = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, chosen)
        return chosen
def init_tokenizer_from_configs(model_config: ModelConfig,
                                scheduler_config: SchedulerConfig,
                                lora_config: Optional[LoRAConfig]):
    """Create a ``TokenizerGroup`` from the engine configuration objects.

    Args:
        model_config: Supplies the tokenizer id, tokenizer mode,
            ``trust_remote_code``, tokenizer revision and the runner type
            that decides the truncation side.
        scheduler_config: Supplies ``max_num_seqs``.
        lora_config: LoRA settings; a falsy value disables LoRA and sets
            ``max_loras`` to 0.

    Returns:
        The constructed ``TokenizerGroup`` (with ``max_input_length=None``).
    """
    runner_type = model_config.runner_type
    # "pooling" truncates on the right; "generate" and "draft" on the left.
    if runner_type == "pooling":
        truncation_side = "right"
    elif runner_type == "generate" or runner_type == "draft":
        truncation_side = "left"
    else:
        assert_never(runner_type)

    lora_enabled = bool(lora_config)
    max_loras = lora_config.max_loras if lora_enabled else 0

    return TokenizerGroup(
        tokenizer_id=model_config.tokenizer,
        enable_lora=lora_enabled,
        max_num_seqs=scheduler_config.max_num_seqs,
        max_loras=max_loras,
        max_input_length=None,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,
        truncation_side=truncation_side,
    )
......@@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase):
def max_token_id(self) -> int:
return self._max_token_id
@property
def truncation_side(self) -> str:
    """Truncation side; not implemented here, so access always raises."""
    raise NotImplementedError
def __len__(self) -> int:
return self.vocab_size
......
......@@ -29,8 +29,8 @@ from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
deprecate_kwargs)
......@@ -112,9 +112,7 @@ class AsyncLLM(EngineClient):
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
model_config=vllm_config.model_config)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
......@@ -596,15 +594,12 @@ class AsyncLLM(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
async def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer.get_lora_tokenizer(lora_request)
return self.tokenizer
async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None
......
......@@ -20,8 +20,8 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
......@@ -89,9 +89,7 @@ class LLMEngine:
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
model_config=vllm_config.model_config)
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
......@@ -297,7 +295,7 @@ class LLMEngine:
assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot()
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
......
......@@ -14,7 +14,6 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
......@@ -290,7 +289,7 @@ class RequestState:
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
......@@ -347,10 +346,7 @@ class OutputProcessor:
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
tokenizer = None if not self.tokenizer else \
self.tokenizer.get_lora_tokenizer(request.lora_request)
req_state = RequestState.from_new_request(tokenizer=tokenizer,
req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
request=request,
prompt=prompt,
parent_req=parent_req,
......
......@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
......@@ -17,7 +18,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
......@@ -28,13 +29,15 @@ from vllm.v1.structured_output.backend_outlines import (
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
logger = init_logger(__name__)
class Processor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerGroup,
tokenizer: AnyTokenizer,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
......@@ -90,7 +93,6 @@ class Processor:
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
......@@ -103,8 +105,7 @@ class Processor:
# When skip_tokenizer_init=True, we can't validate token IDs
# Skip validation and let the model handle invalid tokens
return
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
vocab_size = len(self.tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
......@@ -144,7 +145,6 @@ class Processor:
def _validate_params(
self,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest],
):
"""
Validate supported SamplingParam.
......@@ -155,14 +155,14 @@ class Processor:
return
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
self._validate_sampling_params(params)
self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
......@@ -202,10 +202,22 @@ class Processor:
_validate_single_prompt(prompt) # type: ignore[arg-type]
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
if lora_request is None:
return
# LoRA request passed in while LoRA is not enabled
if not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if self.tokenizer is not None:
logger.warning_once(
"vLLM has deprecated support for supporting different "
"tokenizers for different LoRAs. By default, vLLM uses base "
"model's tokenizer. If you are using a LoRA "
"with its own tokenizer, consider specifying `--tokenizer "
"[lora_path]` to use the LoRA tokenizer.")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
......@@ -326,7 +338,7 @@ class Processor:
# TODO(woosuk): Support pooling models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
self._validate_params(params)
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
......@@ -365,7 +377,6 @@ class Processor:
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
......@@ -375,9 +386,9 @@ class Processor:
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
eos_token_id = self.input_preprocessor.get_eos_token_id()
self._validate_model_inputs(processed_inputs, lora_request)
self._validate_model_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
......@@ -394,8 +405,7 @@ class Processor:
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
if self.tokenizer is not None:
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
sampling_params.update_from_tokenizer(self.tokenizer)
else:
pooling_params = params.clone()
......@@ -436,24 +446,17 @@ class Processor:
trace_headers=trace_headers,
)
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
......@@ -469,7 +472,7 @@ class Processor:
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
tokenizer = self.tokenizer
max_input_id = max(prompt_ids, default=0)
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
......
......@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
......@@ -60,10 +60,7 @@ class StructuredOutputManager:
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
model_config=self.vllm_config.model_config)
reasoning_backend = \
self.vllm_config.decoding_config.reasoning_backend
if reasoning_backend:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment