Commit e69a2651 (unverified), authored by Walter Beller-Morales, committed by GitHub
Browse files

[Feat][Core] safely abort requests when FSM fails to advance (#38663)


Signed-off-by: walterbm <walter.beller.morales@gmail.com>
parent fef56c18
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque from collections import deque
from unittest.mock import Mock
import pytest import pytest
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
...@@ -247,3 +249,66 @@ def test_prefix_caching_for_multi_turn(): ...@@ -247,3 +249,66 @@ def test_prefix_caching_for_multi_turn():
# requests. # requests.
for req in next_turn_requests: for req in next_turn_requests:
assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE
def test_abort_request_when_structured_output_fsm_cannot_advance():
    """A request whose grammar rejects the sampled token is finished with an error.

    The scheduler is allocated without running ``__init__`` and only the
    attributes this test assigns are populated. The grammar mock reports that
    it cannot accept the sampled token, and the test then checks that
    ``update_from_output`` marks the request ``FINISHED_ERROR``, clears
    ``resumable`` and removes it from the scheduler's bookkeeping.
    """
    # Bypass the constructor; every attribute used below is set by hand.
    sched = object.__new__(AsyncScheduler)

    # A running request whose grammar refuses to advance.
    fsm_request = create_requests(num_requests=1, num_tokens=1)[0]
    req_id = fsm_request.request_id
    fsm_request.structured_output_request = Mock()
    fsm_request.structured_output_request.grammar = Mock()
    fsm_request.structured_output_request.grammar.accept_tokens.return_value = False
    fsm_request.status = RequestStatus.RUNNING
    fsm_request.num_output_placeholders = 1
    fsm_request.num_computed_tokens = fsm_request.num_tokens

    # Request bookkeeping.
    sched.requests = {req_id: fsm_request}
    sched.running = [fsm_request]
    sched.waiting = Mock()
    sched.finished_req_ids = set()
    sched.finished_req_ids_dict = None

    # Collaborators, mocked or disabled.
    sched.structured_output_manager = Mock()
    sched.structured_output_manager.should_advance.return_value = True
    sched.kv_cache_manager = Mock()
    sched.kv_cache_manager.take_events.return_value = None
    sched.kv_event_publisher = Mock()
    sched.connector = None
    sched.perf_metrics = None
    sched.make_stats = Mock(return_value=None)

    # Plain configuration values.
    sched.vllm_config = Mock()
    sched.vllm_config.model_config.enable_return_routed_experts = False
    sched.recompute_kv_load_failures = False
    sched.max_model_len = 128

    def fake_free_request(req, delay_free_blocks=False):
        # Stand-in for the real method: record the id as finished and drop
        # the request from the scheduler's request table.
        sched.finished_req_ids.add(req.request_id)
        sched.requests.pop(req.request_id, None)
        return None

    sched._free_request = Mock(side_effect=fake_free_request)

    # One scheduled token for the request, answered by one sampled token.
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
    runner_output = ModelRunnerOutput(
        req_ids=[req_id],
        req_id_to_index={req_id: 0},
        sampled_token_ids=[[123]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    sched.update_from_output(scheduler_output, runner_output)

    assert fsm_request.resumable is False
    assert fsm_request.status == RequestStatus.FINISHED_ERROR
    assert fsm_request.request_id not in sched.requests
    assert not sched.running
...@@ -26,6 +26,7 @@ from vllm.v1.core.encoder_cache_manager import EncoderCacheManager ...@@ -26,6 +26,7 @@ from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import FinishReason
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheConfig,
...@@ -2463,6 +2464,86 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): ...@@ -2463,6 +2464,86 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(scheduler.skipped_waiting) == 1 assert len(scheduler.skipped_waiting) == 1
def test_abort_request_when_structured_output_fsm_cannot_advance():
    """A grammar rejection terminates the request and reports an ERROR finish.

    The scheduler is allocated without running ``__init__`` and only the
    attributes this test assigns are populated. The grammar mock rejects the
    sampled token. The test then checks the request state, the scheduler's
    bookkeeping, the call to ``_free_request`` and the single emitted engine
    core output.
    """
    # Bypass the constructor; every attribute used below is set by hand.
    sched = object.__new__(Scheduler)

    params = SamplingParams(ignore_eos=True, max_tokens=4)
    params.update_from_generation_config({}, EOS_TOKEN_ID)

    # A running request whose grammar refuses to advance.
    fsm_request = Request(
        request_id="0",
        prompt_token_ids=[0, 1],
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
    )
    req_id = fsm_request.request_id
    fsm_request.structured_output_request = Mock()
    fsm_request.structured_output_request.grammar = Mock()
    fsm_request.structured_output_request.grammar.accept_tokens.return_value = False
    fsm_request.status = RequestStatus.RUNNING
    fsm_request.num_computed_tokens = fsm_request.num_tokens

    # Request bookkeeping.
    sched.requests = {req_id: fsm_request}
    sched.running = [fsm_request]
    sched.waiting = Mock()
    sched.finished_req_ids = set()
    sched.finished_req_ids_dict = None

    # Collaborators, mocked or disabled.
    sched.structured_output_manager = Mock()
    sched.structured_output_manager.should_advance.return_value = True
    sched.kv_cache_manager = Mock()
    sched.kv_cache_manager.take_events.return_value = None
    sched.kv_event_publisher = Mock()
    sched.connector = None
    sched.perf_metrics = None
    sched.make_stats = Mock(return_value=None)

    # Plain configuration values.
    sched.vllm_config = Mock()
    sched.vllm_config.model_config.enable_return_routed_experts = False
    sched.recompute_kv_load_failures = False
    sched.max_model_len = 128

    def fake_free_request(req: Request, delay_free_blocks: bool = False):
        # Stand-in for the real method: record the id as finished and drop
        # the request from the scheduler's request table.
        sched.finished_req_ids.add(req.request_id)
        sched.requests.pop(req.request_id, None)
        return None

    sched._free_request = Mock(side_effect=fake_free_request)

    # One scheduled token for the request, answered by one sampled token.
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
    runner_output = ModelRunnerOutput(
        req_ids=[req_id],
        req_id_to_index={req_id: 0},
        sampled_token_ids=[[123]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    core_outputs = sched.update_from_output(scheduler_output, runner_output)

    # The grammar was consulted exactly once with the sampled token.
    fsm_request.structured_output_request.grammar.accept_tokens.assert_called_once_with(
        fsm_request.request_id, [123]
    )

    # The request is terminated and removed from the scheduler.
    assert fsm_request.resumable is False
    assert fsm_request.status == RequestStatus.FINISHED_ERROR
    assert fsm_request.request_id not in sched.requests
    assert not sched.running
    sched._free_request.assert_called_once_with(fsm_request)

    # Exactly one output is emitted under key 0, carrying the rejected token
    # and an ERROR finish reason.
    assert len(core_outputs[0].outputs) == 1
    emitted = core_outputs[0].outputs[0]
    assert emitted.request_id == fsm_request.request_id
    assert emitted.new_token_ids == [123]
    assert emitted.finish_reason == FinishReason.ERROR
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
) )
......
...@@ -1406,6 +1406,23 @@ class Scheduler(SchedulerInterface): ...@@ -1406,6 +1406,23 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.FINISHED_STOPPED request.status = RequestStatus.FINISHED_STOPPED
stopped = True stopped = True
if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
if not struct_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids
):
logger.error(
"Unexpected: grammar rejected tokens %s for request %s. "
"Terminating request.",
new_token_ids,
req_id,
)
request.status = RequestStatus.FINISHED_ERROR
request.resumable = False
stopped = True
routed_experts = None routed_experts = None
finish_reason = None finish_reason = None
if stopped: if stopped:
...@@ -1431,18 +1448,6 @@ class Scheduler(SchedulerInterface): ...@@ -1431,18 +1448,6 @@ class Scheduler(SchedulerInterface):
): ):
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment