Unverified Commit ea5ff3a1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Simplify BOS/EOS token handling (#34435)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 04ea31ba
...@@ -83,6 +83,7 @@ def test_grammar_bitmask_with_specdec(): ...@@ -83,6 +83,7 @@ def test_grammar_bitmask_with_specdec():
), ),
) )
sampling_params.structured_outputs._backend = "guidance" sampling_params.structured_outputs._backend = "guidance"
sampling_params.update_from_generation_config({}, tokenizer.eos_token_id)
my_req_id = f"my_req_id_{i}" my_req_id = f"my_req_id_{i}"
request = Request( request = Request(
...@@ -90,7 +91,6 @@ def test_grammar_bitmask_with_specdec(): ...@@ -90,7 +91,6 @@ def test_grammar_bitmask_with_specdec():
prompt_token_ids=prompt[:i], prompt_token_ids=prompt[:i],
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
) )
structured_output_manager.grammar_init(request) structured_output_manager.grammar_init(request)
...@@ -147,13 +147,13 @@ def test_grammar_init_async_and_sync(async_grammar): ...@@ -147,13 +147,13 @@ def test_grammar_init_async_and_sync(async_grammar):
), ),
) )
sampling_params.structured_outputs._backend = "guidance" sampling_params.structured_outputs._backend = "guidance"
sampling_params.update_from_generation_config({}, tokenizer.eos_token_id)
request = Request( request = Request(
"test_request", "test_request",
prompt_token_ids=prompt, prompt_token_ids=prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
) )
structured_output_manager.grammar_init(request) structured_output_manager.grammar_init(request)
......
...@@ -77,24 +77,6 @@ class InputPreprocessor: ...@@ -77,24 +77,6 @@ class InputPreprocessor:
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.renderer.get_tokenizer() return self.renderer.get_tokenizer()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> int: def get_decoder_start_token_id(self) -> int:
""" """
Obtain the decoder start token id employed by an encoder/decoder Obtain the decoder start token id employed by an encoder/decoder
...@@ -106,11 +88,10 @@ class InputPreprocessor: ...@@ -106,11 +88,10 @@ class InputPreprocessor:
if dec_start_token_id is None: if dec_start_token_id is None:
logger.warning_once( logger.warning_once(
"Falling back on <BOS> for decoder start token " "Falling back on <BOS> for decoder start token id "
"id because decoder start token id is not " "because decoder start token id is not available."
"available."
) )
dec_start_token_id = self.get_bos_token_id() dec_start_token_id = self.renderer.get_bos_token_id()
if dec_start_token_id is None: if dec_start_token_id is None:
raise RuntimeError("Cannot find decoder start token id or <BOS>") raise RuntimeError("Cannot find decoder start token id or <BOS>")
......
...@@ -6,6 +6,7 @@ from collections.abc import Sequence ...@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload from typing import TYPE_CHECKING, Any, overload
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
...@@ -26,6 +27,8 @@ if TYPE_CHECKING: ...@@ -26,6 +27,8 @@ if TYPE_CHECKING:
ConversationMessage, ConversationMessage,
) )
logger = init_logger(__name__)
class BaseRenderer(ABC): class BaseRenderer(ABC):
@classmethod @classmethod
...@@ -63,6 +66,24 @@ class BaseRenderer(ABC): ...@@ -63,6 +66,24 @@ class BaseRenderer(ABC):
return self._async_tokenizer return self._async_tokenizer
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for BOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.bos_token_id
def get_eos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
"Using None for EOS token id because tokenizer is not initialized"
)
return None
return self.tokenizer.eos_token_id
# Step 1: Convert raw inputs to prompts # Step 1: Convert raw inputs to prompts
def render_prompt( def render_prompt(
self, self,
......
...@@ -223,6 +223,7 @@ class SamplingParams( ...@@ -223,6 +223,7 @@ class SamplingParams(
# The below fields are not supposed to be used as an input. # The below fields are not supposed to be used as an input.
# They are set in post_init. # They are set in post_init.
output_text_buffer_length: int = 0 output_text_buffer_length: int = 0
_eos_token_id: int | None = None
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors # Fields used to construct logits processors
...@@ -477,24 +478,26 @@ class SamplingParams( ...@@ -477,24 +478,26 @@ class SamplingParams(
def update_from_generation_config( def update_from_generation_config(
self, self,
generation_config: dict[str, Any], generation_config: dict[str, Any],
model_eos_token_id: int | None = None, eos_token_id: int | None = None,
) -> None: ) -> None:
"""Update if there are non-default values from generation_config""" """Update if there are non-default values from generation_config"""
if not self.ignore_eos:
self._eos_token_id = eos_token_id
if model_eos_token_id is not None: if eos_token_id is not None:
# Add the eos token id into the sampling_params to support # Add the eos token id into the sampling_params to support
# min_tokens processing. # min_tokens processing.
self._all_stop_token_ids.add(model_eos_token_id) self._all_stop_token_ids.add(eos_token_id)
# Update eos_token_id for generation # Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None: if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int # it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None: if eos_token_id is not None:
# We don't need to include the primary eos_token_id in # We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping # stop_token_ids since it's handled separately for stopping
# purposes. # purposes.
eos_ids.discard(model_eos_token_id) eos_ids.discard(eos_token_id)
if eos_ids: if eos_ids:
self._all_stop_token_ids.update(eos_ids) self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos: if not self.ignore_eos:
...@@ -550,6 +553,10 @@ class SamplingParams( ...@@ -550,6 +553,10 @@ class SamplingParams(
return SamplingType.RANDOM_SEED return SamplingType.RANDOM_SEED
return SamplingType.RANDOM return SamplingType.RANDOM
@property
def eos_token_id(self) -> int | None:
return self._eos_token_id
@property @property
def all_stop_token_ids(self) -> set[int]: def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids return self._all_stop_token_ids
......
...@@ -47,7 +47,7 @@ def check_stop(request: Request, max_model_len: int) -> bool: ...@@ -47,7 +47,7 @@ def check_stop(request: Request, max_model_len: int) -> bool:
return False return False
last_token_id = request.output_token_ids[-1] last_token_id = request.output_token_ids[-1]
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id: if last_token_id == sampling_params.eos_token_id:
request.status = RequestStatus.FINISHED_STOPPED request.status = RequestStatus.FINISHED_STOPPED
return True return True
......
...@@ -9,6 +9,7 @@ from typing import Any, Literal ...@@ -9,6 +9,7 @@ from typing import Any, Literal
import msgspec import msgspec
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
...@@ -63,7 +64,6 @@ class EngineCoreRequest( ...@@ -63,7 +64,6 @@ class EngineCoreRequest(
mm_features: list[MultiModalFeatureSpec] | None mm_features: list[MultiModalFeatureSpec] | None
sampling_params: SamplingParams | None sampling_params: SamplingParams | None
pooling_params: PoolingParams | None pooling_params: PoolingParams | None
eos_token_id: int | None
arrival_time: float arrival_time: float
lora_request: LoRARequest | None lora_request: LoRARequest | None
cache_salt: str | None cache_salt: str | None
...@@ -99,6 +99,17 @@ class EngineCoreRequest( ...@@ -99,6 +99,17 @@ class EngineCoreRequest(
assert self.pooling_params is not None assert self.pooling_params is not None
return self.pooling_params return self.pooling_params
@property
@deprecated(
"EngineCoreRequest.eos_token_id will be removed in v0.18. "
"Please use EngineCoreRequest.sampling_params.eos_token_id instead."
)
def eos_token_id(self) -> int | None:
if self.sampling_params is None:
return None
return self.sampling_params.eos_token_id
class EngineCoreEventType(enum.IntEnum): class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event.""" """The type of engine core request event."""
......
...@@ -376,8 +376,6 @@ class InputProcessor: ...@@ -376,8 +376,6 @@ class InputProcessor:
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
) )
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs)
...@@ -403,7 +401,7 @@ class InputProcessor: ...@@ -403,7 +401,7 @@ class InputProcessor:
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, self.generation_config_fields,
None if self.tokenizer is None else self.tokenizer.eos_token_id, self.renderer.get_eos_token_id(),
) )
if self.tokenizer is not None: if self.tokenizer is not None:
sampling_params.update_from_tokenizer(self.tokenizer) sampling_params.update_from_tokenizer(self.tokenizer)
...@@ -446,7 +444,6 @@ class InputProcessor: ...@@ -446,7 +444,6 @@ class InputProcessor:
mm_features=mm_features, mm_features=mm_features,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"), cache_salt=decoder_inputs.get("cache_salt"),
......
...@@ -9,6 +9,7 @@ from dataclasses import dataclass ...@@ -9,6 +9,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import torch import torch
from typing_extensions import deprecated
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -62,7 +63,6 @@ class Request: ...@@ -62,7 +63,6 @@ class Request:
prompt_token_ids: list[int] | None, prompt_token_ids: list[int] | None,
sampling_params: SamplingParams | None, sampling_params: SamplingParams | None,
pooling_params: PoolingParams | None, pooling_params: PoolingParams | None,
eos_token_id: int | None,
client_index: int = 0, client_index: int = 0,
arrival_time: float | None = None, arrival_time: float | None = None,
prompt_embeds: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None,
...@@ -80,8 +80,6 @@ class Request: ...@@ -80,8 +80,6 @@ class Request:
self.priority = priority self.priority = priority
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.pooling_params = pooling_params self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.structured_output_request = StructuredOutputRequest.from_sampling_params( self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params sampling_params
...@@ -176,6 +174,17 @@ class Request: ...@@ -176,6 +174,17 @@ class Request:
# None entry in the queue means finished. # None entry in the queue means finished.
self.streaming_queue: deque[StreamingUpdate | None] | None = None self.streaming_queue: deque[StreamingUpdate | None] | None = None
@property
@deprecated(
"Request.eos_token_id will be removed in v0.18. "
"Please use Request.sampling_params.eos_token_id instead."
)
def eos_token_id(self) -> int | None:
if self.sampling_params is None:
return None
return self.sampling_params.eos_token_id
@classmethod @classmethod
def from_engine_core_request( def from_engine_core_request(
cls, cls,
...@@ -190,7 +199,6 @@ class Request: ...@@ -190,7 +199,6 @@ class Request:
mm_features=request.mm_features, mm_features=request.mm_features,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time, arrival_time=request.arrival_time,
lora_request=request.lora_request, lora_request=request.lora_request,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
......
...@@ -185,14 +185,13 @@ re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") ...@@ -185,14 +185,13 @@ re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$") re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$")
def _reduced_vocabulary( def _reduced_vocabulary(tokenizer: TokenizerLike) -> dict[bytes, list[int]]:
tokenizer: TokenizerLike, eos_token_id: int
) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids. """Create a map from vocabulary tokens to lists of equivalent token ids.
Returns: Returns:
A Dict of token string -> equivalent token ids A Dict of token string -> equivalent token ids
""" """
eos_token_id = tokenizer.eos_token_id
unicode_to_bytes = { unicode_to_bytes = {
v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items() v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items()
...@@ -260,30 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary: ...@@ -260,30 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
if hasattr(tokenizer, "_outlines_vocabulary"): if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore return tokenizer._outlines_vocabulary # type: ignore
try: reduced_vocab = _reduced_vocabulary(tokenizer)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: vocabulary = OutlinesVocabulary(
eos_token_id = tokenizer.eos_token_id oc.Vocabulary(tokenizer.eos_token_id, reduced_vocab)
else:
raise ValueError(
"Error during structured outputs setup for outlines: Tokenizer "
f"({type(tokenizer)}) has no `eos_token_id` property, but "
"`eos_token_id` is required for structured outputs to work properly."
) )
reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id, # type: ignore
)
vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary # type: ignore tokenizer._outlines_vocabulary = vocabulary # type: ignore
return vocabulary return vocabulary
except AttributeError as e:
raise ValueError(
"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method."
) from e
def grammar_is_likely_lark(grammar_str: str) -> bool: def grammar_is_likely_lark(grammar_str: str) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment