"vscode:/vscode.git/clone" did not exist on "8679faa35d2ed872babd1f8d160f5f259e01d1aa"
Unverified commit 8bb9a555, authored by Jacky and committed by GitHub
Browse files

feat: Enable vLLM abort while engine is generating the next token (#3102)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent ce36d9f4
......@@ -6,6 +6,7 @@ import logging
import os
import uuid
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import AsyncGenerator
......@@ -40,6 +41,36 @@ class BaseWorkerHandler(ABC):
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill):
    """Abort ``request_id`` on the engine once ``context`` is stopped or killed.

    Meant to run as a background task next to token generation, so that an
    abort is issued without waiting for the engine to produce another token.

    Args:
        context: Request context; must provide the awaitable
            ``async_killed_or_stopped()``.
        request_id: Identifier forwarded to ``self.engine_client.abort``.
        is_prefill: When true, the debug log line is labelled "Prefill".

    Returns:
        None. A cancellation delivered while this coroutine is running is
        swallowed, and any other exception is logged instead of raised.
    """
    try:
        await context.async_killed_or_stopped()
        # Reaching this point means the context was stopped or killed.
        await self.engine_client.abort(request_id)
        # Lazy %-args: the message is only rendered when DEBUG is enabled.
        logger.debug(
            "Aborted %sRequest ID: %s",
            "Prefill " if is_prefill else "",
            request_id,
        )
    except asyncio.CancelledError:
        # The monitor was cancelled before the context was stopped or
        # killed; there is nothing to abort and nothing to clean up.
        pass
    except Exception as e:
        # Best-effort monitor: never raise into the awaiter, but keep the
        # traceback so a failing abort can be diagnosed from the logs.
        logger.exception(
            "Error in abort monitor for request %s: %s", request_id, e
        )
@asynccontextmanager
async def _abort_monitor(self, context, request_id, is_prefill=False):
    """Run ``_monitor_abort`` as a background task for the ``async with`` body.

    The task is created on entry and yielded to the caller. On exit, whether
    normal or via an exception, a task that has not finished is cancelled
    and awaited, so it does not outlive the block.

    Args:
        context: Request context handed to ``_monitor_abort``.
        request_id: Identifier of the request the monitor would abort.
        is_prefill: Forwarded to ``_monitor_abort``; defaults to ``False``.

    Yields:
        The ``asyncio.Task`` wrapping ``_monitor_abort``.
    """
    monitor = asyncio.create_task(
        self._monitor_abort(context, request_id, is_prefill)
    )
    try:
        yield monitor
    finally:
        # A monitor that already completed needs no teardown.
        still_running = not monitor.done()
        if still_running:
            monitor.cancel()
            try:
                # Wait for the cancellation to be processed by the task.
                await monitor
            except asyncio.CancelledError:
                pass
async def clear_kv_blocks(self, request=None):
try:
await self.engine_client.reset_prefix_cache()
......@@ -202,10 +233,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return
raise e
......@@ -229,21 +258,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"kv_transfer_params"
] = prefill_response.kv_transfer_params
try:
async for tok in self.generate_tokens(prompt, sampling_params, request_id):
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
async with self._abort_monitor(context, request_id):
try:
async for tok in self.generate_tokens(
prompt, sampling_params, request_id
):
yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler):
......@@ -257,36 +282,31 @@ class PrefillWorkerHandler(BaseWorkerHandler):
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill
try:
async for res in gen:
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
prompt_logprobs=res.prompt_logprobs,
outputs=res.outputs,
finished=res.finished,
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
prompt_logprobs=res.prompt_logprobs,
outputs=res.outputs,
finished=res.finished,
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
......@@ -412,9 +412,7 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info(
"Checking for cancellation messages in worker and frontend logs..."
)
# TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
time.sleep(0.05) # time for cancellation to propagate
frontend_log_offset, worker_log_offset = verify_request_cancelled(
frontend,
worker,
......@@ -467,9 +465,7 @@ def test_request_cancellation_vllm_decode(
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
# TODO: Need to wait for the next token to generate before seeing the
# cancellation on the logs. DIS-625
time.sleep(0.5)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(frontend, decode_worker, prefill_worker)
......@@ -507,10 +503,6 @@ def test_request_cancellation_vllm_prefill(
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why the model is not immediately available at the frontend after
# health check returns success.
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only
logger.info(
"Testing completion request cancellation in prefill worker..."
......@@ -520,9 +512,7 @@ def test_request_cancellation_vllm_prefill(
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
# TODO: Need to wait for prefill to generate first token before seeing
# the cancellation on the logs. DIS-625
time.sleep(3)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(
frontend,
decode_worker,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment