Unverified commit ea5ff3a1, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Simplify BOS/EOS token handling (#34435)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 04ea31ba
...@@ -39,7 +39,6 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): ...@@ -39,7 +39,6 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=0.0, arrival_time=0.0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -35,7 +35,6 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0): ...@@ -35,7 +35,6 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=0.0, arrival_time=0.0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -67,7 +67,6 @@ def _run_incremental_decode( ...@@ -67,7 +67,6 @@ def _run_incremental_decode(
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=0.0, arrival_time=0.0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -1123,7 +1123,7 @@ rectangle ...@@ -1123,7 +1123,7 @@ rectangle
# Encode all content tokens at once # Encode all content tokens at once
all_token_ids = step3p5_tokenizer.encode(model_output, add_special_tokens=False) all_token_ids = step3p5_tokenizer.encode(model_output, add_special_tokens=False)
eos_token_id = getattr(step3p5_tokenizer, "eos_token_id", None) eos_token_id = step3p5_tokenizer.eos_token_id
# Include EOS token in delta_token_ids if available # Include EOS token in delta_token_ids if available
if eos_token_id is not None: if eos_token_id is not None:
......
...@@ -84,13 +84,15 @@ def make_request( ...@@ -84,13 +84,15 @@ def make_request(
) )
mm_features.append(mm_feature) mm_features.append(mm_feature)
sampling_params = SamplingParams(max_tokens=17)
sampling_params.update_from_generation_config({}, eos_token_id=100)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
mm_features=mm_features if mm_features else None, mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17), sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=100,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
......
...@@ -75,13 +75,15 @@ def make_request( ...@@ -75,13 +75,15 @@ def make_request(
) )
mm_features.append(mm_feature) mm_features.append(mm_feature)
sampling_params = SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs)
sampling_params.update_from_generation_config({}, eos_token_id=100)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
mm_features=mm_features if mm_features else None, mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=100,
lora_request=lora_request, lora_request=lora_request,
cache_salt=cache_salt, cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
......
...@@ -48,10 +48,9 @@ def _create_random_request( ...@@ -48,10 +48,9 @@ def _create_random_request(
request_id = uuid.uuid4().hex request_id = uuid.uuid4().hex
sampling_params = SamplingParams( sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens)
ignore_eos=False, sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
max_tokens=max_tokens,
)
mm_features = [] mm_features = []
for j, position in enumerate(mm_positions): for j, position in enumerate(mm_positions):
identifier = f"{request_id}_hash_{j}" identifier = f"{request_id}_hash_{j}"
...@@ -79,7 +78,6 @@ def _create_random_request( ...@@ -79,7 +78,6 @@ def _create_random_request(
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
mm_features=mm_features if mm_features else None, mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_time, arrival_time=arrival_time,
priority=priority, priority=priority,
block_hasher=block_hasher, block_hasher=block_hasher,
......
...@@ -469,8 +469,7 @@ def test_stop_via_update_from_output(): ...@@ -469,8 +469,7 @@ def test_stop_via_update_from_output():
# Test case 4: Ignore EOS flag # Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2) scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10) requests = create_requests(num_requests=1, max_tokens=10, ignore_eos=True)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0] scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0]) scheduler.running.append(requests[0])
...@@ -515,12 +514,12 @@ def test_check_stop_min_tokens(): ...@@ -515,12 +514,12 @@ def test_check_stop_min_tokens():
max_tokens=20, max_tokens=20,
min_tokens=5, min_tokens=5,
) )
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request( request = Request(
request_id="0", request_id="0",
prompt_token_ids=[0, 1, 2], prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
) )
# Simulate having generated 3 output tokens (less than min_tokens=5) # Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
...@@ -551,12 +550,12 @@ def test_check_stop_min_tokens(): ...@@ -551,12 +550,12 @@ def test_check_stop_min_tokens():
max_tokens=20, max_tokens=20,
min_tokens=0, min_tokens=0,
) )
sampling_params_no_min.update_from_generation_config({}, EOS_TOKEN_ID)
request_no_min = Request( request_no_min = Request(
request_id="1", request_id="1",
prompt_token_ids=[0, 1, 2], prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min, sampling_params=sampling_params_no_min,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
) )
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID]) request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
...@@ -571,12 +570,12 @@ def test_check_stop_min_tokens(): ...@@ -571,12 +570,12 @@ def test_check_stop_min_tokens():
min_tokens=5, min_tokens=5,
stop_token_ids=[42], stop_token_ids=[42],
) )
sampling_params_stop.update_from_generation_config({}, EOS_TOKEN_ID)
request_stop = Request( request_stop = Request(
request_id="2", request_id="2",
prompt_token_ids=[0, 1, 2], prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop, sampling_params=sampling_params_stop,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
) )
# Only 3 output tokens, less than min_tokens=5, but has stop token # Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42]) request_stop.append_output_token_ids([10, 11, 42])
...@@ -1877,6 +1876,7 @@ def create_requests_with_priority( ...@@ -1877,6 +1876,7 @@ def create_requests_with_priority(
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
) )
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
requests = [] requests = []
if mm_hashes_list is not None: if mm_hashes_list is not None:
...@@ -1938,7 +1938,6 @@ def create_requests_with_priority( ...@@ -1938,7 +1938,6 @@ def create_requests_with_priority(
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
mm_features=mm_features if mm_features else None, mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i], arrival_time=arrival_times[i],
priority=priorities[i], priority=priorities[i],
block_hasher=block_hasher, block_hasher=block_hasher,
...@@ -2429,13 +2428,13 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): ...@@ -2429,13 +2428,13 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
max_tokens=16, max_tokens=16,
structured_outputs=structured_outputs_params, structured_outputs=structured_outputs_params,
) )
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request( request = Request(
request_id="0", request_id="0",
prompt_token_ids=[0, 1], prompt_token_ids=[0, 1],
mm_features=None, mm_features=None,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
) )
scheduler.add_request(request) scheduler.add_request(request)
output = scheduler.schedule() output = scheduler.schedule()
......
...@@ -174,6 +174,7 @@ def create_requests( ...@@ -174,6 +174,7 @@ def create_requests(
num_tokens: int = 10, num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None, mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None, mm_positions: list[list[PlaceholderRange]] | None = None,
ignore_eos: bool = False,
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: list[int] | None = None, stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None, prompt_logprobs: int | None = None,
...@@ -188,11 +189,12 @@ def create_requests( ...@@ -188,11 +189,12 @@ def create_requests(
block_hasher = get_request_block_hasher(block_size, sha256) block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams( sampling_params = SamplingParams(
ignore_eos=False, ignore_eos=ignore_eos,
max_tokens=max_tokens, max_tokens=max_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
) )
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
requests = [] requests = []
if mm_hashes_list is not None: if mm_hashes_list is not None:
...@@ -250,7 +252,6 @@ def create_requests( ...@@ -250,7 +252,6 @@ def create_requests(
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
mm_features=mm_features if mm_features else None, mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher, block_hasher=block_hasher,
) )
requests.append(request) requests.append(request)
......
...@@ -54,7 +54,6 @@ def make_request() -> EngineCoreRequest: ...@@ -54,7 +54,6 @@ def make_request() -> EngineCoreRequest:
mm_features=None, mm_features=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -69,7 +69,6 @@ def make_request( ...@@ -69,7 +69,6 @@ def make_request(
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -32,7 +32,6 @@ def test_fast_inc_detok_invalid_utf8_err_case(): ...@@ -32,7 +32,6 @@ def test_fast_inc_detok_invalid_utf8_err_case():
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=0.0, arrival_time=0.0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -66,7 +66,6 @@ def test_incremental_detokenization( ...@@ -66,7 +66,6 @@ def test_incremental_detokenization(
external_req_id=f"request-{idx}", external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
...@@ -487,7 +486,6 @@ def test_logprobs_processor( ...@@ -487,7 +486,6 @@ def test_logprobs_processor(
external_req_id=request_id_list[idx], external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
...@@ -663,6 +661,19 @@ def test_stop_token( ...@@ -663,6 +661,19 @@ def test_stop_token(
prompt_string = dummy_test_vectors.prompt_strings[0] prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0] prompt_tokens = dummy_test_vectors.prompt_tokens[0]
sampling_params = SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=[],
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
)
sampling_params.update_from_generation_config({}, eos_token_id)
# Make request. # Make request.
request_id = "request-0" request_id = "request-0"
request = EngineCoreRequest( request = EngineCoreRequest(
...@@ -670,22 +681,11 @@ def test_stop_token( ...@@ -670,22 +681,11 @@ def test_stop_token(
external_req_id=request_id + "-ext", external_req_id=request_id + "-ext",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=eos_token_id,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
sampling_params=SamplingParams( sampling_params=sampling_params,
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=[],
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
),
pooling_params=None, pooling_params=None,
) )
...@@ -693,9 +693,8 @@ def test_stop_token( ...@@ -693,9 +693,8 @@ def test_stop_token(
tokens_list=[generation_tokens], tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None, generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None, prompt_logprobs_raw=None,
eos_token_id=eos_token_id, eos_token_id=sampling_params.eos_token_id,
stop_token_ids=stop_token_ids, stop_token_ids=sampling_params.stop_token_ids,
ignore_eos=ignore_eos,
request_ids=[request.request_id], request_ids=[request.request_id],
) )
...@@ -775,7 +774,6 @@ def test_stop_string( ...@@ -775,7 +774,6 @@ def test_stop_string(
external_req_id=request_id_list[idx], external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
...@@ -907,7 +905,6 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -907,7 +905,6 @@ def test_iteration_stats(dummy_test_vectors):
external_req_id=f"request-{idx}-ext", external_req_id=f"request-{idx}-ext",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
...@@ -994,7 +991,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors): ...@@ -994,7 +991,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
external_req_id=f"request-{idx}", external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=lora_assignments[idx], lora_request=lora_assignments[idx],
cache_salt=None, cache_salt=None,
...@@ -1315,7 +1311,6 @@ def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors): ...@@ -1315,7 +1311,6 @@ def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
external_req_id=f"external-{idx}", external_req_id=f"external-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -76,7 +76,6 @@ def make_request(sampling_params: SamplingParams) -> EngineCoreRequest: ...@@ -76,7 +76,6 @@ def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
mm_features=None, mm_features=None,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=None,
arrival_time=0.0, arrival_time=0.0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
......
...@@ -342,7 +342,6 @@ class MockEngineCore: ...@@ -342,7 +342,6 @@ class MockEngineCore:
prompt_logprobs_raw: list[LogprobsTensors] | None = None, prompt_logprobs_raw: list[LogprobsTensors] | None = None,
eos_token_id: int | None = None, eos_token_id: int | None = None,
stop_token_ids: list[int] | None = None, stop_token_ids: list[int] | None = None,
ignore_eos: bool = False,
request_ids: list[str] | None = None, request_ids: list[str] | None = None,
) -> None: ) -> None:
self.num_requests = len(tokens_list) self.num_requests = len(tokens_list)
...@@ -355,7 +354,6 @@ class MockEngineCore: ...@@ -355,7 +354,6 @@ class MockEngineCore:
self.request_finished = [False for _ in range(self.num_requests)] self.request_finished = [False for _ in range(self.num_requests)]
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos
self.request_ids = ( self.request_ids = (
request_ids request_ids
if request_ids is not None if request_ids is not None
...@@ -400,7 +398,7 @@ class MockEngineCore: ...@@ -400,7 +398,7 @@ class MockEngineCore:
if token_idx == len(token_ids) - 1: if token_idx == len(token_ids) - 1:
output.finish_reason = FinishReason.LENGTH output.finish_reason = FinishReason.LENGTH
self.request_finished[req_idx] = True self.request_finished[req_idx] = True
if not self.ignore_eos and new_token_id == self.eos_token_id: if new_token_id == self.eos_token_id:
output.finish_reason = FinishReason.STOP output.finish_reason = FinishReason.STOP
self.request_finished[req_idx] = True self.request_finished[req_idx] = True
if new_token_id in (self.stop_token_ids or ()): if new_token_id in (self.stop_token_ids or ()):
......
...@@ -93,12 +93,14 @@ class DecodeBenchTestRunner: ...@@ -93,12 +93,14 @@ class DecodeBenchTestRunner:
"""Create a new request with given token IDs.""" """Create a new request with given token IDs."""
self.req_id += 1 self.req_id += 1
sampling_params = SamplingParams(max_tokens=100)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
req = Request( req = Request(
request_id=str(self.req_id), request_id=str(self.req_id),
prompt_token_ids=token_ids, prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=100), sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher, block_hasher=self._block_hasher,
) )
......
...@@ -142,12 +142,14 @@ def test_request_interface(): ...@@ -142,12 +142,14 @@ def test_request_interface():
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
sampling_params = SamplingParams(max_tokens=10)
sampling_params.update_from_generation_config({}, eos_token_id=100)
req = Request( req = Request(
request_id="test_request", request_id="test_request",
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
sampling_params=SamplingParams(max_tokens=10), sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=100,
lora_request=None, lora_request=None,
) )
assumes(req, "mm_features", is_instance_of=(list, NoneType)) assumes(req, "mm_features", is_instance_of=(list, NoneType))
......
...@@ -226,12 +226,14 @@ class RequestRunner: ...@@ -226,12 +226,14 @@ class RequestRunner:
def new_request(self, token_ids: list[int]): def new_request(self, token_ids: list[int]):
self.req_id += 1 self.req_id += 1
sampling_params = SamplingParams(max_tokens=1000)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
req = Request( req = Request(
request_id=str(self.req_id), request_id=str(self.req_id),
prompt_token_ids=token_ids, prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=1000), sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher, block_hasher=self._block_hasher,
) )
......
...@@ -212,6 +212,7 @@ def create_request( ...@@ -212,6 +212,7 @@ def create_request(
max_tokens = 1 if do_remote_decode else max_tokens max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens) sampling_params = SamplingParams(max_tokens=max_tokens)
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else [] common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
suffix = [i * request_id for i in range(num_tokens - common_prefix_len)] suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
...@@ -223,7 +224,6 @@ def create_request( ...@@ -223,7 +224,6 @@ def create_request(
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
mm_features=None, mm_features=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
) )
req.kv_transfer_params = kv_transfer_params req.kv_transfer_params = kv_transfer_params
......
...@@ -43,7 +43,6 @@ class DummyRequest(Request): ...@@ -43,7 +43,6 @@ class DummyRequest(Request):
stop_token_ids=[STOP_TOKEN], max_tokens=max_tokens stop_token_ids=[STOP_TOKEN], max_tokens=max_tokens
), ),
pooling_params=None, pooling_params=None,
eos_token_id=None,
mm_features=mm_features, mm_features=mm_features,
resumable=resumable, resumable=resumable,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment