"deploy/vscode:/vscode.git/clone" did not exist on "61f16716bc16fe15d692968335a40f23457da9da"
Unverified Commit 8bb9a555 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Enable vLLM abort while engine is generating the next token (#3102)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent ce36d9f4
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import os import os
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy from copy import deepcopy
from typing import AsyncGenerator from typing import AsyncGenerator
...@@ -40,6 +41,36 @@ class BaseWorkerHandler(ABC): ...@@ -40,6 +41,36 @@ class BaseWorkerHandler(ABC):
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill):
"""Background task that monitors for context cancellation and aborts the request."""
try:
await context.async_killed_or_stopped()
# If we reach here, the context was stopped or killed
await self.engine_client.abort(request_id)
logger.debug(
f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}"
)
except asyncio.CancelledError:
# Task was cancelled, normal cleanup if not aborted
pass
except Exception as e:
logger.error(f"Error in abort monitor for request {request_id}: {e}")
@asynccontextmanager
async def _abort_monitor(self, context, request_id, is_prefill=False):
"""Context manager that creates and automatically cleans up an abort monitoring task."""
task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill))
try:
yield task
finally:
# Cancel the abort monitoring task when exiting the context
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def clear_kv_blocks(self, request=None): async def clear_kv_blocks(self, request=None):
try: try:
await self.engine_client.reset_prefix_cache() await self.engine_client.reset_prefix_cache()
...@@ -202,10 +233,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -202,10 +233,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) )
except Exception as e: except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed(): if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}") logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return return
raise e raise e
...@@ -229,16 +258,12 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -229,16 +258,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"kv_transfer_params" "kv_transfer_params"
] = prefill_response.kv_transfer_params ] = prefill_response.kv_transfer_params
async with self._abort_monitor(context, request_id):
try: try:
async for tok in self.generate_tokens(prompt, sampling_params, request_id): async for tok in self.generate_tokens(
if context.is_stopped() or context.is_killed(): prompt, sampling_params, request_id
await self.engine_client.abort(request_id) ):
logger.debug(f"Aborted Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
yield tok yield tok
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.") logger.warning("Initiating Dynamo Runtime shutdown.")
...@@ -257,6 +282,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -257,6 +282,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
prompt = TokensPrompt(prompt_token_ids=request["token_ids"]) prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams) sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
async with self._abort_monitor(context, request_id, is_prefill=True):
try: try:
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(prompt, sampling_params, request_id)
except EngineDeadError as e: except EngineDeadError as e:
...@@ -268,12 +294,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -268,12 +294,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
# Generate only 1 token in prefill # Generate only 1 token in prefill
try: try:
async for res in gen: async for res in gen:
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
logger.debug(f"kv transfer params: {res.kv_transfer_params}") logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput( yield MyRequestOutput(
request_id=res.request_id, request_id=res.request_id,
......
...@@ -412,9 +412,7 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models ...@@ -412,9 +412,7 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info( logger.info(
"Checking for cancellation messages in worker and frontend logs..." "Checking for cancellation messages in worker and frontend logs..."
) )
# TODO: Need to wait for the next token to generate before seeing the time.sleep(0.05) # time for cancellation to propagate
# cancellation on the logs. DIS-625
time.sleep(0.5)
frontend_log_offset, worker_log_offset = verify_request_cancelled( frontend_log_offset, worker_log_offset = verify_request_cancelled(
frontend, frontend,
worker, worker,
...@@ -467,9 +465,7 @@ def test_request_cancellation_vllm_decode( ...@@ -467,9 +465,7 @@ def test_request_cancellation_vllm_decode(
logger.info( logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..." "Checking for cancellation messages in decode and prefill worker and frontend logs..."
) )
# TODO: Need to wait for the next token to generate before seeing the time.sleep(0.05) # time for cancellation to propagate
# cancellation on the logs. DIS-625
time.sleep(0.5)
verify_request_cancelled(frontend, decode_worker, prefill_worker) verify_request_cancelled(frontend, decode_worker, prefill_worker)
...@@ -507,10 +503,6 @@ def test_request_cancellation_vllm_prefill( ...@@ -507,10 +503,6 @@ def test_request_cancellation_vllm_prefill(
with decode_worker: with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why the model is not immediately available at the frontend after
# health check returns success.
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only # Step 4: Test request cancellation for completion scenario only
logger.info( logger.info(
"Testing completion request cancellation in prefill worker..." "Testing completion request cancellation in prefill worker..."
...@@ -520,9 +512,7 @@ def test_request_cancellation_vllm_prefill( ...@@ -520,9 +512,7 @@ def test_request_cancellation_vllm_prefill(
logger.info( logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..." "Checking for cancellation messages in decode and prefill worker and frontend logs..."
) )
# TODO: Need to wait for prefill to generate first token before seeing time.sleep(0.05) # time for cancellation to propagate
# the cancellation on the logs. DIS-625
time.sleep(3)
verify_request_cancelled( verify_request_cancelled(
frontend, frontend,
decode_worker, decode_worker,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment