Unverified Commit ea5ff3a1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Simplify BOS/EOS token handling (#34435)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 04ea31ba
......@@ -83,6 +83,7 @@ def test_grammar_bitmask_with_specdec():
),
)
sampling_params.structured_outputs._backend = "guidance"
sampling_params.update_from_generation_config({}, tokenizer.eos_token_id)
my_req_id = f"my_req_id_{i}"
request = Request(
......@@ -90,7 +91,6 @@ def test_grammar_bitmask_with_specdec():
prompt_token_ids=prompt[:i],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
)
structured_output_manager.grammar_init(request)
......@@ -147,13 +147,13 @@ def test_grammar_init_async_and_sync(async_grammar):
),
)
sampling_params.structured_outputs._backend = "guidance"
sampling_params.update_from_generation_config({}, tokenizer.eos_token_id)
request = Request(
"test_request",
prompt_token_ids=prompt,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
)
structured_output_manager.grammar_init(request)
......
......@@ -77,24 +77,6 @@ class InputPreprocessor:
def get_tokenizer(self) -> TokenizerLike:
return self.renderer.get_tokenizer()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> int:
"""
Obtain the decoder start token id employed by an encoder/decoder
......@@ -106,11 +88,10 @@ class InputPreprocessor:
if dec_start_token_id is None:
logger.warning_once(
"Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available."
"Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available."
)
dec_start_token_id = self.get_bos_token_id()
dec_start_token_id = self.renderer.get_bos_token_id()
if dec_start_token_id is None:
raise RuntimeError("Cannot find decoder start token id or <BOS>")
......
......@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
......@@ -26,6 +27,8 @@ if TYPE_CHECKING:
ConversationMessage,
)
logger = init_logger(__name__)
class BaseRenderer(ABC):
@classmethod
......@@ -63,6 +66,24 @@ class BaseRenderer(ABC):
return self._async_tokenizer
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
# Step 1: Convert raw inputs to prompts
def render_prompt(
self,
......
......@@ -223,6 +223,7 @@ class SamplingParams(
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
_eos_token_id: int | None = None
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
......@@ -477,24 +478,26 @@ class SamplingParams(
def update_from_generation_config(
self,
generation_config: dict[str, Any],
model_eos_token_id: int | None = None,
eos_token_id: int | None = None,
) -> None:
"""Update if there are non-default values from generation_config"""
if not self.ignore_eos:
self._eos_token_id = eos_token_id
if model_eos_token_id is not None:
if eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self._all_stop_token_ids.add(model_eos_token_id)
self._all_stop_token_ids.add(eos_token_id)
# Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
if eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(model_eos_token_id)
eos_ids.discard(eos_token_id)
if eos_ids:
self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
......@@ -550,6 +553,10 @@ class SamplingParams(
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
@property
def eos_token_id(self) -> int | None:
return self._eos_token_id
@property
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids
......
......@@ -47,7 +47,7 @@ def check_stop(request: Request, max_model_len: int) -> bool:
return False
last_token_id = request.output_token_ids[-1]
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
if last_token_id == sampling_params.eos_token_id:
request.status = RequestStatus.FINISHED_STOPPED
return True
......
......@@ -9,6 +9,7 @@ from typing import Any, Literal
import msgspec
import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
......@@ -63,7 +64,6 @@ class EngineCoreRequest(
mm_features: list[MultiModalFeatureSpec] | None
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
eos_token_id: int | None
arrival_time: float
lora_request: LoRARequest | None
cache_salt: str | None
......@@ -99,6 +99,17 @@ class EngineCoreRequest(
assert self.pooling_params is not None
return self.pooling_params
@property
@deprecated(
"EngineCoreRequest.eos_token_id will be removed in v0.18. "
"Please use EngineCoreRequest.sampling_params.eos_token_id instead."
)
def eos_token_id(self) -> int | None:
if self.sampling_params is None:
return None
return self.sampling_params.eos_token_id
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
......
......@@ -376,8 +376,6 @@ class InputProcessor:
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs)
......@@ -403,7 +401,7 @@ class InputProcessor:
sampling_params.update_from_generation_config(
self.generation_config_fields,
None if self.tokenizer is None else self.tokenizer.eos_token_id,
self.renderer.get_eos_token_id(),
)
if self.tokenizer is not None:
sampling_params.update_from_tokenizer(self.tokenizer)
......@@ -446,7 +444,6 @@ class InputProcessor:
mm_features=mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
......
......@@ -9,6 +9,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import torch
from typing_extensions import deprecated
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
......@@ -62,7 +63,6 @@ class Request:
prompt_token_ids: list[int] | None,
sampling_params: SamplingParams | None,
pooling_params: PoolingParams | None,
eos_token_id: int | None,
client_index: int = 0,
arrival_time: float | None = None,
prompt_embeds: torch.Tensor | None = None,
......@@ -80,8 +80,6 @@ class Request:
self.priority = priority
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params
......@@ -176,6 +174,17 @@ class Request:
# None entry in the queue means finished.
self.streaming_queue: deque[StreamingUpdate | None] | None = None
@property
@deprecated(
"Request.eos_token_id will be removed in v0.18. "
"Please use Request.sampling_params.eos_token_id instead."
)
def eos_token_id(self) -> int | None:
if self.sampling_params is None:
return None
return self.sampling_params.eos_token_id
@classmethod
def from_engine_core_request(
cls,
......@@ -190,7 +199,6 @@ class Request:
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
cache_salt=request.cache_salt,
......
......@@ -185,14 +185,13 @@ re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$")
def _reduced_vocabulary(
tokenizer: TokenizerLike, eos_token_id: int
) -> dict[bytes, list[int]]:
def _reduced_vocabulary(tokenizer: TokenizerLike) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
Returns:
A Dict of token string -> equivalent token ids
"""
eos_token_id = tokenizer.eos_token_id
unicode_to_bytes = {
v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items()
......@@ -260,30 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore
try:
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
"Error during structured outputs setup for outlines: Tokenizer "
f"({type(tokenizer)}) has no `eos_token_id` property, but "
"`eos_token_id` is required for structured outputs to work properly."
)
reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id, # type: ignore
)
vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary # type: ignore
reduced_vocab = _reduced_vocabulary(tokenizer)
vocabulary = OutlinesVocabulary(
oc.Vocabulary(tokenizer.eos_token_id, reduced_vocab)
)
tokenizer._outlines_vocabulary = vocabulary # type: ignore
return vocabulary
except AttributeError as e:
raise ValueError(
"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method."
) from e
return vocabulary
def grammar_is_likely_lark(grammar_str: str) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment