Unverified Commit e0cd8489 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: multi-thread (via asyncio.task) in processor (#904)

parent 191748e0
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import uuid import uuid
from enum import Enum from enum import Enum
from typing import AsyncIterator, Tuple, Union from typing import Any, AsyncIterator, Dict, List, Tuple, Union
from components.kv_router import Router from components.kv_router import Router
from components.worker import VllmWorker from components.worker import VllmWorker
...@@ -68,6 +69,12 @@ class Processor(ProcessMixIn): ...@@ -68,6 +69,12 @@ class Processor(ProcessMixIn):
self.tokenizer, self.model_config self.tokenizer, self.model_config
) )
self.min_workers = 1 self.min_workers = 1
self.request_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self.request_futures: Dict[str, asyncio.Future] = {}
self.num_worker_tasks = (
self.engine_args.router_num_threads
) # Number of worker tasks to process the queue
self.worker_tasks: List[asyncio.Task] = []
print(f"Processor init: {self.engine_args.router}") print(f"Processor init: {self.engine_args.router}")
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer: def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
...@@ -117,6 +124,68 @@ class Processor(ProcessMixIn): ...@@ -117,6 +124,68 @@ class Processor(ProcessMixIn):
{"router": self.engine_args.router}, {"router": self.engine_args.router},
) )
# Start multiple worker tasks to process the queue
self._start_worker_tasks()
def _start_worker_tasks(self):
    """(Re)create the pool of background queue-consumer tasks.

    Any tasks left over from a previous call are cancelled first so the
    queue is never drained by two generations of workers at once.
    """
    # Cancel whatever is still running from an earlier start.
    for stale in self.worker_tasks:
        if not stale.done():
            stale.cancel()
    # Build a fresh pool: one consumer task per configured slot.
    self.worker_tasks = [
        asyncio.create_task(self._process_queue(worker_id=slot))
        for slot in range(self.num_worker_tasks)
    ]
    logger.info(f"Started {self.num_worker_tasks} queue worker tasks")
async def _process_queue(self, worker_id: int):
    """Consume requests from ``self.request_queue`` until cancelled.

    Each dequeued item is handed to ``self._process_request``; failures
    there are logged and swallowed so one bad request cannot kill the
    worker. ``task_done()`` is always called so ``queue.join()`` works.
    Cancellation (CancelledError is a BaseException, so the inner
    ``except Exception`` does not trap it) cleanly ends the loop.
    """
    logger.info(f"Queue worker {worker_id} started")
    running = True
    while running:
        try:
            job = await self.request_queue.get()
            try:
                await self._process_request(job)
            except Exception as e:
                logger.error(f"Worker {worker_id}: Error processing request: {e}")
            finally:
                # Always acknowledge the item, even on failure/cancel.
                self.request_queue.task_done()
        except asyncio.CancelledError:
            logger.info(f"Queue worker {worker_id} was cancelled")
            running = False
        except Exception as e:
            logger.error(
                f"Worker {worker_id}: Unexpected error in queue processing: {e}"
            )
            # Back off briefly so a persistent failure cannot spin hot.
            await asyncio.sleep(0.1)
async def _get_kv_load(self):
    """Return a mapping of worker_id -> GPU KV-cache usage fraction.

    Endpoints that do not report ``gpu_cache_usage_perc`` default to 0.0.
    """
    snapshot = await self.metrics_aggregator.get_metrics()
    return {
        endpoint.worker_id: getattr(endpoint, "gpu_cache_usage_perc", 0.0)
        for endpoint in snapshot.endpoints
    }
async def _get_pending_requests(self):
    """Return a mapping of worker_id -> number of queued (waiting) requests.

    Endpoints that do not report ``num_requests_waiting`` default to 0.
    """
    snapshot = await self.metrics_aggregator.get_metrics()
    return {
        endpoint.worker_id: getattr(endpoint, "num_requests_waiting", 0)
        for endpoint in snapshot.endpoints
    }
async def _generate( async def _generate(
self, self,
raw_request: Union[CompletionRequest, ChatCompletionRequest], raw_request: Union[CompletionRequest, ChatCompletionRequest],
...@@ -124,6 +193,38 @@ class Processor(ProcessMixIn): ...@@ -124,6 +193,38 @@ class Processor(ProcessMixIn):
): ):
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}") logger.debug(f"Got raw request: {raw_request}")
# Create a future for this request
future: asyncio.Future[AsyncIterator[Any]] = asyncio.Future()
self.request_futures[request_id] = future
# Enqueue the request with minimal processing
await self.request_queue.put(
{
"request_id": request_id,
"raw_request": raw_request,
"request_type": request_type,
}
)
try:
# Wait for the future to complete and yield the results
generator = await future
async for response in generator:
yield response
finally:
# Clean up the future when done
if request_id in self.request_futures:
del self.request_futures[request_id]
async def _process_request(self, request_data: Dict[str, Any]):
"""Process a single request from the queue"""
request_id = request_data["request_id"]
raw_request = request_data["raw_request"]
request_type = request_data["request_type"]
try:
# Parse the raw request here instead of in _generate
( (
request, request,
conversation, conversation,
...@@ -132,12 +233,19 @@ class Processor(ProcessMixIn): ...@@ -132,12 +233,19 @@ class Processor(ProcessMixIn):
sampling_params, sampling_params,
) = await self._parse_raw_request(raw_request) ) = await self._parse_raw_request(raw_request)
# Create an async generator function to process this request
async def process_and_stream():
# TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode() router_mode = (await self.etcd_kv_cache.get("router")).decode()
prefix_hit_rate = 0.0
self.use_router = router_mode in (RouterType.KV, RouterType.KV_LOAD)
prefix_hit_rate = 0.0 # Default value
if self.use_router: if self.use_router:
router_generator = await self.router_client.generate( router_generator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json() Tokens(
tokens=engine_prompt["prompt_token_ids"]
).model_dump_json()
) )
decision = await router_generator.__anext__() decision = await router_generator.__anext__()
worker_id, prefix_hit_rate = decision.data() worker_id, prefix_hit_rate = decision.data()
...@@ -153,7 +261,9 @@ class Processor(ProcessMixIn): ...@@ -153,7 +261,9 @@ class Processor(ProcessMixIn):
if self.use_router: if self.use_router:
if worker_id == "": if worker_id == "":
engine_generator = await self.worker_client.generate(request_obj) engine_generator = await self.worker_client.generate(
request_obj
)
else: else:
engine_generator = await self.worker_client.direct( engine_generator = await self.worker_client.direct(
request_obj, int(worker_id) request_obj, int(worker_id)
...@@ -163,13 +273,29 @@ class Processor(ProcessMixIn): ...@@ -163,13 +273,29 @@ class Processor(ProcessMixIn):
elif router_mode == RouterType.ROUND_ROBIN: elif router_mode == RouterType.ROUND_ROBIN:
engine_generator = await self.worker_client.round_robin(request_obj) engine_generator = await self.worker_client.round_robin(request_obj)
output = self._generate_responses(engine_generator, request_type) output_generator = self._generate_responses(
engine_generator, request_type
)
# Stream responses directly to the caller
async for response in await self._stream_response( async for response in await self._stream_response(
request, output, request_id, conversation request, output_generator, request_id, conversation
): ):
yield response yield response
# Set the future result to our async generator
if request_id in self.request_futures:
self.request_futures[request_id].set_result(process_and_stream())
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
# Set exception on the future if it still exists
if (
request_id in self.request_futures
and not self.request_futures[request_id].done()
):
self.request_futures[request_id].set_exception(e)
async def _generate_responses( async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]: ) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
......
...@@ -24,14 +24,13 @@ Frontend: ...@@ -24,14 +24,13 @@ Frontend:
Processor: Processor:
router: round-robin router: round-robin
router-num-threads: 4
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len]
VllmWorker: VllmWorker:
enforce-eager: true enforce-eager: true
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
enable-prefix-caching: true enable-prefix-caching: true
router: random
tensor-parallel-size: 1
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
resources: resources:
......
...@@ -43,6 +43,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -43,6 +43,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default=RouterType.RANDOM, default=RouterType.RANDOM,
help="Router type to use for scheduling requests to workers", help="Router type to use for scheduling requests to workers",
) )
parser.add_argument(
"--router-num-threads",
type=int,
default=4,
help="Number of threads to use for the router to process the requests",
)
parser.add_argument( parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill" "--remote-prefill", action="store_true", help="Enable remote prefill"
) )
...@@ -67,6 +73,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -67,6 +73,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
args = parser.parse_args(vllm_args) args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router engine_args.router = args.router
engine_args.router_num_threads = args.router_num_threads
engine_args.remote_prefill = args.remote_prefill engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length engine_args.max_local_prefill_length = args.max_local_prefill_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment