Unverified Commit fe6b19c3 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Bugfix] Properly abort pooling request. (#25734)


Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 2827b3f4
...@@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, ...@@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS, STOP_STRINGS,
DummyOutputProcessorTestVectors, DummyOutputProcessorTestVectors,
MockEngineCore) MockEngineCore)
from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
...@@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n(): ...@@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n():
third = [k for k in result.outputs if k.index == 2] third = [k for k in result.outputs if k.index == 2]
assert len(third) == 1 assert len(third) == 1
assert third[0].text == "c" assert third[0].text == "c"
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams() if runner == "generate" else None,
pooling_params=PoolingParams(
task="embed") if runner == "pooling" else None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
for request in requests:
if runner == "generate":
output_kind = request.sampling_params.output_kind
else:
output_kind = request.pooling_params.output_kind
queue = RequestOutputCollector(output_kind=output_kind)
output_processor.add_request(request, None, queue=queue)
for request in requests:
output_processor.abort_requests([request.request_id])
...@@ -335,7 +335,14 @@ class OutputProcessor: ...@@ -335,7 +335,14 @@ class OutputProcessor:
# Produce final abort output. # Produce final abort output.
if req_state.queue is not None and ( if req_state.queue is not None and (
request_output := req_state.make_request_output( request_output := req_state.make_request_output(
[], None, FinishReason.ABORT, None, None)): new_token_ids=[],
# Set pooling_output is not None to
# correctly enter the abort pooling branch
pooling_output=torch.randn(0, device="cpu")
if req_state.detokenizer is None else None,
finish_reason=FinishReason.ABORT,
stop_reason=None,
kv_transfer_params=None)):
req_state.queue.put(request_output) req_state.queue.put(request_output)
elif parent := self.parent_requests.get(request_id): elif parent := self.parent_requests.get(request_id):
# Abort children prior to removing the parent. # Abort children prior to removing the parent.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment