Unverified Commit e0cd8489 authored by Hongkuan Zhou, committed by GitHub

feat: multi-thread (via asyncio.task) in processor (#904)

parent 191748e0
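
The core pattern this commit introduces is a pool of asyncio consumer tasks draining a shared `asyncio.Queue`, with a per-request `asyncio.Future` used to hand an async generator back to the caller. A minimal, self-contained sketch of that handshake (all names here are illustrative, not the repo's API):

```python
import asyncio
from typing import AsyncIterator, Dict

async def fake_engine(prompt: str) -> AsyncIterator[str]:
    # Stand-in for the engine's token stream.
    for tok in prompt.split():
        await asyncio.sleep(0.01)
        yield tok

async def consumer(queue: asyncio.Queue, futures: Dict[str, asyncio.Future]):
    while True:
        item = await queue.get()
        try:
            # Hand the caller an async generator via its Future.
            futures[item["id"]].set_result(fake_engine(item["prompt"]))
        finally:
            queue.task_done()

async def submit(queue: asyncio.Queue, futures: Dict[str, asyncio.Future],
                 req_id: str, prompt: str) -> AsyncIterator[str]:
    fut: asyncio.Future = asyncio.get_running_loop().create_future()
    futures[req_id] = fut
    await queue.put({"id": req_id, "prompt": prompt})
    try:
        gen = await fut  # blocks until a consumer picks the request up
        async for tok in gen:
            yield tok
    finally:
        futures.pop(req_id, None)

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    futures: Dict[str, asyncio.Future] = {}
    workers = [asyncio.create_task(consumer(queue, futures)) for _ in range(4)]
    async for tok in submit(queue, futures, "r1", "hello queued world"):
        print(tok)
    for w in workers:
        w.cancel()

asyncio.run(main())
```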
......@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from typing import Any, AsyncIterator, Dict, List, Tuple, Union
from components.kv_router import Router
from components.worker import VllmWorker
......@@ -68,6 +69,12 @@ class Processor(ProcessMixIn):
            self.tokenizer, self.model_config
        )
        self.min_workers = 1
        self.request_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
        self.request_futures: Dict[str, asyncio.Future] = {}
        self.num_worker_tasks = (
            self.engine_args.router_num_threads
        )  # Number of worker tasks to process the queue
        self.worker_tasks: List[asyncio.Task] = []
        print(f"Processor init: {self.engine_args.router}")

    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
......@@ -117,6 +124,68 @@ class Processor(ProcessMixIn):
{"router": self.engine_args.router},
)
# Start multiple worker tasks to process the queue
self._start_worker_tasks()
def _start_worker_tasks(self):
"""Start multiple worker tasks to process the queue concurrently"""
# Clear any existing worker tasks
for task in self.worker_tasks:
if not task.done():
task.cancel()
self.worker_tasks = []
# Create new worker tasks
for i in range(self.num_worker_tasks):
task = asyncio.create_task(self._process_queue(worker_id=i))
self.worker_tasks.append(task)
logger.info(f"Started {self.num_worker_tasks} queue worker tasks")
    async def _process_queue(self, worker_id: int):
        """Background task to process the request queue"""
        logger.info(f"Queue worker {worker_id} started")
        while True:
            try:
                # Get the next request from the queue
                request_data = await self.request_queue.get()
                # Process the request
                try:
                    await self._process_request(request_data)
                except Exception as e:
                    logger.error(f"Worker {worker_id}: Error processing request: {e}")
                finally:
                    # Mark the task as done
                    self.request_queue.task_done()
            except asyncio.CancelledError:
                logger.info(f"Queue worker {worker_id} was cancelled")
                break
            except Exception as e:
                logger.error(
                    f"Worker {worker_id}: Unexpected error in queue processing: {e}"
                )
                # Sleep briefly to avoid tight error loops
                await asyncio.sleep(0.1)
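
The `task_done()` call in the `finally` block is what makes `queue.join()` usable for draining: `join()` unblocks only once every enqueued item has a matching `task_done()`. A small stdlib-only demonstration of that contract:

```python
import asyncio

async def demo():
    q: asyncio.Queue = asyncio.Queue()
    for i in range(3):
        q.put_nowait(i)

    async def drain():
        while True:
            item = await q.get()
            try:
                print("processed", item)
            finally:
                q.task_done()  # without this, q.join() below never returns

    worker = asyncio.create_task(drain())
    await q.join()  # returns once all three items are marked done
    worker.cancel()

asyncio.run(demo())
```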
    async def _get_kv_load(self):
        metrics = await self.metrics_aggregator.get_metrics()
        kv_load = {}
        for endpoint in metrics.endpoints:
            worker_id = endpoint.worker_id
            kv_load[worker_id] = getattr(endpoint, "gpu_cache_usage_perc", 0.0)
        return kv_load

    async def _get_pending_requests(self):
        metrics = await self.metrics_aggregator.get_metrics()
        pending_requests = {}
        for endpoint in metrics.endpoints:
            worker_id = endpoint.worker_id
            pending_requests[worker_id] = getattr(endpoint, "num_requests_waiting", 0)
        return pending_requests
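
Neither helper is called elsewhere in this hunk, but the obvious consumer is a load-aware scheduling decision. A hedged sketch of how the two per-worker dicts might be combined to pick a target (`pick_worker` and the weighting are illustrative, not from the source):

```python
def pick_worker(kv_load: dict, pending: dict, kv_weight: float = 0.7) -> str:
    # Illustrative scoring only: blend KV-cache usage with queue depth
    # and take the least-loaded worker.
    def score(worker_id):
        waiting = pending.get(worker_id, 0)
        return kv_weight * kv_load[worker_id] + (1 - kv_weight) * waiting
    return min(kv_load, key=score)

# e.g. pick_worker({"a": 0.9, "b": 0.2}, {"a": 3, "b": 3}) -> "b"
```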
    async def _generate(
        self,
        raw_request: Union[CompletionRequest, ChatCompletionRequest],
......@@ -124,51 +193,108 @@
    ):
        request_id = str(uuid.uuid4())
        logger.debug(f"Got raw request: {raw_request}")
        # Create a future for this request
        future: asyncio.Future[AsyncIterator[Any]] = asyncio.Future()
        self.request_futures[request_id] = future

        # Enqueue the request with minimal processing; parsing and routing
        # now happen in the queue workers (see _process_request below)
        await self.request_queue.put(
            {
                "request_id": request_id,
                "raw_request": raw_request,
                "request_type": request_type,
            }
        )

        try:
            # Wait for the future to complete and yield the results
            generator = await future
            async for response in generator:
                yield response
        finally:
            # Clean up the future when done
            if request_id in self.request_futures:
                del self.request_futures[request_id]
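
One sharp edge in this flow: if no consumer ever resolves the future (for instance, every queue worker has crashed), `generator = await future` blocks indefinitely. A defensive variant, purely illustrative and not in this commit, would bound the wait:

```python
import asyncio

async def await_with_deadline(fut: asyncio.Future, request_id: str,
                              timeout: float = 30.0):
    # Hypothetical guard: fail fast if no queue worker ever resolves
    # the request's future (wait_for cancels the future on timeout).
    try:
        return await asyncio.wait_for(fut, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"request {request_id} was never scheduled")
```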
    async def _process_request(self, request_data: Dict[str, Any]):
        """Process a single request from the queue"""
        request_id = request_data["request_id"]
        raw_request = request_data["raw_request"]
        request_type = request_data["request_type"]

        try:
            # Parse the raw request here instead of in _generate
            (
                request,
                conversation,
                prompt,
                engine_prompt,
                sampling_params,
            ) = await self._parse_raw_request(raw_request)

            # Create an async generator function to process this request
            async def process_and_stream():
                # TODO: queue request at processor when engines are full
                router_mode = (await self.etcd_kv_cache.get("router")).decode()
                self.use_router = router_mode in (RouterType.KV, RouterType.KV_LOAD)

                prefix_hit_rate = 0.0  # Default value
                if self.use_router:
                    router_generator = await self.router_client.generate(
                        Tokens(
                            tokens=engine_prompt["prompt_token_ids"]
                        ).model_dump_json()
                    )
                    decision = await router_generator.__anext__()
                    worker_id, prefix_hit_rate = decision.data()
                    prefix_hit_rate = float(prefix_hit_rate)

                # Create request object once with default prefix_hit_rate
                request_obj = vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
                    sampling_params=sampling_params,
                    request_id=request_id,
                    prefix_hit_rate=prefix_hit_rate,
                ).model_dump_json()

                if self.use_router:
                    if worker_id == "":
                        engine_generator = await self.worker_client.generate(
                            request_obj
                        )
                    else:
                        engine_generator = await self.worker_client.direct(
                            request_obj, int(worker_id)
                        )
                elif router_mode == RouterType.RANDOM:
                    engine_generator = await self.worker_client.generate(request_obj)
                elif router_mode == RouterType.ROUND_ROBIN:
                    engine_generator = await self.worker_client.round_robin(request_obj)
                else:
                    # Guard against leaving engine_generator unbound
                    raise ValueError(f"Unknown router mode: {router_mode}")

                output_generator = self._generate_responses(
                    engine_generator, request_type
                )

                # Stream responses directly to the caller
                async for response in await self._stream_response(
                    request, output_generator, request_id, conversation
                ):
                    yield response

            # Set the future result to our async generator
            if request_id in self.request_futures:
                self.request_futures[request_id].set_result(process_and_stream())
        except Exception as e:
            logger.error(f"Error processing request {request_id}: {e}")
            # Set exception on the future if it still exists
            if (
                request_id in self.request_futures
                and not self.request_futures[request_id].done()
            ):
                self.request_futures[request_id].set_exception(e)
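
The `set_exception` branch is what carries parse and routing failures back across the queue boundary: an exception stored on a `Future` re-raises at the caller's `await` in `_generate`. In miniature:

```python
import asyncio

async def demo():
    fut: asyncio.Future = asyncio.get_running_loop().create_future()
    fut.set_exception(ValueError("bad request"))
    try:
        await fut  # the stored exception re-raises here
    except ValueError as e:
        print("caller sees:", e)  # caller sees: bad request

asyncio.run(demo())
```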
    async def _generate_responses(
        self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
......
......@@ -24,14 +24,13 @@ Frontend:
Processor:
  router: round-robin
  router-num-threads: 4
  common-configs: [model, block-size, max-model-len]
VllmWorker:
  enforce-eager: true
  max-num-batched-tokens: 16384
  enable-prefix-caching: true
  router: random
  tensor-parallel-size: 1
  ServiceArgs:
    workers: 1
    resources:
......
......@@ -43,6 +43,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
        default=RouterType.RANDOM,
        help="Router type to use for scheduling requests to workers",
    )
    parser.add_argument(
        "--router-num-threads",
        type=int,
        default=4,
        help="Number of worker tasks the router uses to process requests",
    )
    parser.add_argument(
        "--remote-prefill", action="store_true", help="Enable remote prefill"
    )
......@@ -67,6 +73,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
    args = parser.parse_args(vllm_args)
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine_args.router = args.router
    engine_args.router_num_threads = args.router_num_threads
    engine_args.remote_prefill = args.remote_prefill
    engine_args.conditional_disagg = args.conditional_disagg
    engine_args.max_local_prefill_length = args.max_local_prefill_length
......
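
To see the new flag end to end, a reduced reproduction of the parsing path (standalone argparse only; `parse_vllm_args` itself wires in many more options, and `SimpleNamespace` stands in for `AsyncEngineArgs` here):

```python
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--router-num-threads", type=int, default=4)
args = parser.parse_args(["--router-num-threads", "8"])

# Mirror of how the diff copies the parsed value onto engine_args.
engine_args = SimpleNamespace()
engine_args.router_num_threads = args.router_num_threads
print(engine_args.router_num_threads)  # 8
```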