Unverified Commit 8bb9a555 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Enable vLLM abort while engine is generating the next token (#3102)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent ce36d9f4
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import os import os
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy from copy import deepcopy
from typing import AsyncGenerator from typing import AsyncGenerator
...@@ -40,6 +41,36 @@ class BaseWorkerHandler(ABC): ...@@ -40,6 +41,36 @@ class BaseWorkerHandler(ABC):
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill):
"""Background task that monitors for context cancellation and aborts the request."""
try:
await context.async_killed_or_stopped()
# If we reach here, the context was stopped or killed
await self.engine_client.abort(request_id)
logger.debug(
f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}"
)
except asyncio.CancelledError:
# Task was cancelled, normal cleanup if not aborted
pass
except Exception as e:
logger.error(f"Error in abort monitor for request {request_id}: {e}")
@asynccontextmanager
async def _abort_monitor(self, context, request_id, is_prefill=False):
"""Context manager that creates and automatically cleans up an abort monitoring task."""
task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill))
try:
yield task
finally:
# Cancel the abort monitoring task when exiting the context
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def clear_kv_blocks(self, request=None): async def clear_kv_blocks(self, request=None):
try: try:
await self.engine_client.reset_prefix_cache() await self.engine_client.reset_prefix_cache()
...@@ -202,10 +233,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -202,10 +233,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) )
except Exception as e: except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed(): if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}") logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return return
raise e raise e
...@@ -229,21 +258,17 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -229,21 +258,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"kv_transfer_params" "kv_transfer_params"
] = prefill_response.kv_transfer_params ] = prefill_response.kv_transfer_params
try: async with self._abort_monitor(context, request_id):
async for tok in self.generate_tokens(prompt, sampling_params, request_id): try:
if context.is_stopped() or context.is_killed(): async for tok in self.generate_tokens(
await self.engine_client.abort(request_id) prompt, sampling_params, request_id
logger.debug(f"Aborted Request ID: {request_id}") ):
# TODO: Raise asyncio.CancelledError into bindings yield tok
break except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
yield tok logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
except EngineDeadError as e: os._exit(1)
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler): class PrefillWorkerHandler(BaseWorkerHandler):
...@@ -257,36 +282,31 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -257,36 +282,31 @@ class PrefillWorkerHandler(BaseWorkerHandler):
prompt = TokensPrompt(prompt_token_ids=request["token_ids"]) prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams) sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
try: async with self._abort_monitor(context, request_id, is_prefill=True):
gen = self.engine_client.generate(prompt, sampling_params, request_id) try:
except EngineDeadError as e: gen = self.engine_client.generate(prompt, sampling_params, request_id)
logger.error(f"vLLM EngineDeadError: {e}") except EngineDeadError as e:
logger.warning("Initiating Dynamo Runtime shutdown.") logger.error(f"vLLM EngineDeadError: {e}")
self.runtime.shutdown() logger.warning("Initiating Dynamo Runtime shutdown.")
os._exit(1) self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill
try: # Generate only 1 token in prefill
async for res in gen: try:
if context.is_stopped() or context.is_killed(): async for res in gen:
await self.engine_client.abort(request_id) logger.debug(f"kv transfer params: {res.kv_transfer_params}")
logger.debug(f"Aborted Prefill Request ID: {request_id}") yield MyRequestOutput(
# TODO: Raise asyncio.CancelledError into bindings request_id=res.request_id,
break prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
logger.debug(f"kv transfer params: {res.kv_transfer_params}") prompt_logprobs=res.prompt_logprobs,
yield MyRequestOutput( outputs=res.outputs,
request_id=res.request_id, finished=res.finished,
prompt=res.prompt, metrics=res.metrics,
prompt_token_ids=res.prompt_token_ids, kv_transfer_params=res.kv_transfer_params,
prompt_logprobs=res.prompt_logprobs, ).model_dump_json()
outputs=res.outputs, except asyncio.CancelledError:
finished=res.finished, # raise the error because we cannot migrate prefill requests
metrics=res.metrics, raise GeneratorExit(
kv_transfer_params=res.kv_transfer_params, "Prefill engine was shut down during token generation"
).model_dump_json() ) from None
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
...@@ -412,9 +412,7 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models ...@@ -412,9 +412,7 @@ def test_request_cancellation_vllm(request, runtime_services, predownload_models
logger.info( logger.info(
"Checking for cancellation messages in worker and frontend logs..." "Checking for cancellation messages in worker and frontend logs..."
) )
# TODO: Need to wait for the next token to generate before seeing the time.sleep(0.05) # time for cancellation to propagate
# cancellation on the logs. DIS-625
time.sleep(0.5)
frontend_log_offset, worker_log_offset = verify_request_cancelled( frontend_log_offset, worker_log_offset = verify_request_cancelled(
frontend, frontend,
worker, worker,
...@@ -467,9 +465,7 @@ def test_request_cancellation_vllm_decode( ...@@ -467,9 +465,7 @@ def test_request_cancellation_vllm_decode(
logger.info( logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..." "Checking for cancellation messages in decode and prefill worker and frontend logs..."
) )
# TODO: Need to wait for the next token to generate before seeing the time.sleep(0.05) # time for cancellation to propagate
# cancellation on the logs. DIS-625
time.sleep(0.5)
verify_request_cancelled(frontend, decode_worker, prefill_worker) verify_request_cancelled(frontend, decode_worker, prefill_worker)
...@@ -507,10 +503,6 @@ def test_request_cancellation_vllm_prefill( ...@@ -507,10 +503,6 @@ def test_request_cancellation_vllm_prefill(
with decode_worker: with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why the model is not immediately available at the frontend after
# health check returns success.
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only # Step 4: Test request cancellation for completion scenario only
logger.info( logger.info(
"Testing completion request cancellation in prefill worker..." "Testing completion request cancellation in prefill worker..."
...@@ -520,9 +512,7 @@ def test_request_cancellation_vllm_prefill( ...@@ -520,9 +512,7 @@ def test_request_cancellation_vllm_prefill(
logger.info( logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..." "Checking for cancellation messages in decode and prefill worker and frontend logs..."
) )
# TODO: Need to wait for prefill to generate first token before seeing time.sleep(0.05) # time for cancellation to propagate
# the cancellation on the logs. DIS-625
time.sleep(3)
verify_request_cancelled( verify_request_cancelled(
frontend, frontend,
decode_worker, decode_worker,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment