Unverified commit ea5ff3a1, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Simplify BOS/EOS token handling (#34435)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 04ea31ba
......@@ -39,7 +39,6 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
......
......@@ -35,7 +35,6 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
......
......@@ -67,7 +67,6 @@ def _run_incremental_decode(
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
......
......@@ -1123,7 +1123,7 @@ rectangle
# Encode all content tokens at once
all_token_ids = step3p5_tokenizer.encode(model_output, add_special_tokens=False)
eos_token_id = getattr(step3p5_tokenizer, "eos_token_id", None)
eos_token_id = step3p5_tokenizer.eos_token_id
# Include EOS token in delta_token_ids if available
if eos_token_id is not None:
......
......@@ -84,13 +84,15 @@ def make_request(
)
mm_features.append(mm_feature)
sampling_params = SamplingParams(max_tokens=17)
sampling_params.update_from_generation_config({}, eos_token_id=100)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17),
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn),
......
......@@ -75,13 +75,15 @@ def make_request(
)
mm_features.append(mm_feature)
sampling_params = SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs)
sampling_params.update_from_generation_config({}, eos_token_id=100)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs),
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=100,
lora_request=lora_request,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn),
......
......@@ -48,10 +48,9 @@ def _create_random_request(
request_id = uuid.uuid4().hex
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=max_tokens,
)
sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
mm_features = []
for j, position in enumerate(mm_positions):
identifier = f"{request_id}_hash_{j}"
......@@ -79,7 +78,6 @@ def _create_random_request(
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_time,
priority=priority,
block_hasher=block_hasher,
......
......@@ -469,8 +469,7 @@ def test_stop_via_update_from_output():
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests = create_requests(num_requests=1, max_tokens=10, ignore_eos=True)
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
......@@ -515,12 +514,12 @@ def test_check_stop_min_tokens():
max_tokens=20,
min_tokens=5,
)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request(
request_id="0",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
......@@ -551,12 +550,12 @@ def test_check_stop_min_tokens():
max_tokens=20,
min_tokens=0,
)
sampling_params_no_min.update_from_generation_config({}, EOS_TOKEN_ID)
request_no_min = Request(
request_id="1",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
......@@ -571,12 +570,12 @@ def test_check_stop_min_tokens():
min_tokens=5,
stop_token_ids=[42],
)
sampling_params_stop.update_from_generation_config({}, EOS_TOKEN_ID)
request_stop = Request(
request_id="2",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42])
......@@ -1877,6 +1876,7 @@ def create_requests_with_priority(
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs,
)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
requests = []
if mm_hashes_list is not None:
......@@ -1938,7 +1938,6 @@ def create_requests_with_priority(
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
block_hasher=block_hasher,
......@@ -2429,13 +2428,13 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
max_tokens=16,
structured_outputs=structured_outputs_params,
)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
scheduler.add_request(request)
output = scheduler.schedule()
......
......@@ -174,6 +174,7 @@ def create_requests(
num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None,
ignore_eos: bool = False,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
......@@ -188,11 +189,12 @@ def create_requests(
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(
ignore_eos=False,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs,
)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
requests = []
if mm_hashes_list is not None:
......@@ -250,7 +252,6 @@ def create_requests(
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
requests.append(request)
......
......@@ -54,7 +54,6 @@ def make_request() -> EngineCoreRequest:
mm_features=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
......
......@@ -69,7 +69,6 @@ def make_request(
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
......
......@@ -32,7 +32,6 @@ def test_fast_inc_detok_invalid_utf8_err_case():
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
......
......@@ -66,7 +66,6 @@ def test_incremental_detokenization(
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
......@@ -487,7 +486,6 @@ def test_logprobs_processor(
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
......@@ -663,6 +661,19 @@ def test_stop_token(
prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
sampling_params = SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=[],
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
)
sampling_params.update_from_generation_config({}, eos_token_id)
# Make request.
request_id = "request-0"
request = EngineCoreRequest(
......@@ -670,22 +681,11 @@ def test_stop_token(
external_req_id=request_id + "-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=eos_token_id,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=[],
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
),
sampling_params=sampling_params,
pooling_params=None,
)
......@@ -693,9 +693,8 @@ def test_stop_token(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
eos_token_id=sampling_params.eos_token_id,
stop_token_ids=sampling_params.stop_token_ids,
request_ids=[request.request_id],
)
......@@ -775,7 +774,6 @@ def test_stop_string(
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
......@@ -907,7 +905,6 @@ def test_iteration_stats(dummy_test_vectors):
external_req_id=f"request-{idx}-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
......@@ -994,7 +991,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=lora_assignments[idx],
cache_salt=None,
......@@ -1315,7 +1311,6 @@ def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
external_req_id=f"external-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
......
......@@ -76,7 +76,6 @@ def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
......
......@@ -342,7 +342,6 @@ class MockEngineCore:
prompt_logprobs_raw: list[LogprobsTensors] | None = None,
eos_token_id: int | None = None,
stop_token_ids: list[int] | None = None,
ignore_eos: bool = False,
request_ids: list[str] | None = None,
) -> None:
self.num_requests = len(tokens_list)
......@@ -355,7 +354,6 @@ class MockEngineCore:
self.request_finished = [False for _ in range(self.num_requests)]
self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos
self.request_ids = (
request_ids
if request_ids is not None
......@@ -400,7 +398,7 @@ class MockEngineCore:
if token_idx == len(token_ids) - 1:
output.finish_reason = FinishReason.LENGTH
self.request_finished[req_idx] = True
if not self.ignore_eos and new_token_id == self.eos_token_id:
if new_token_id == self.eos_token_id:
output.finish_reason = FinishReason.STOP
self.request_finished[req_idx] = True
if new_token_id in (self.stop_token_ids or ()):
......
......@@ -93,12 +93,14 @@ class DecodeBenchTestRunner:
"""Create a new request with given token IDs."""
self.req_id += 1
sampling_params = SamplingParams(max_tokens=100)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=100),
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
......
......@@ -142,12 +142,14 @@ def test_request_interface():
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
sampling_params = SamplingParams(max_tokens=10)
sampling_params.update_from_generation_config({}, eos_token_id=100)
req = Request(
request_id="test_request",
prompt_token_ids=[1, 2, 3],
sampling_params=SamplingParams(max_tokens=10),
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=100,
lora_request=None,
)
assumes(req, "mm_features", is_instance_of=(list, NoneType))
......
......@@ -226,12 +226,14 @@ class RequestRunner:
def new_request(self, token_ids: list[int]):
self.req_id += 1
sampling_params = SamplingParams(max_tokens=1000)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=1000),
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
......
......@@ -212,6 +212,7 @@ def create_request(
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
......@@ -223,7 +224,6 @@ def create_request(
sampling_params=sampling_params,
pooling_params=None,
mm_features=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
req.kv_transfer_params = kv_transfer_params
......
......@@ -43,7 +43,6 @@ class DummyRequest(Request):
stop_token_ids=[STOP_TOKEN], max_tokens=max_tokens
),
pooling_params=None,
eos_token_id=None,
mm_features=mm_features,
resumable=resumable,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment