Unverified Commit 37158f20 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)

parent 3e95aa1a
......@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
import copy
import dataclasses
import logging
import os
......@@ -11,6 +12,7 @@ import signal
import sys
import threading
import time
import uuid
from typing import Any, Dict, List, Optional, Union
import grpc
......@@ -79,11 +81,9 @@ class GrpcReqState:
last_completion_tokens: int = 1
# Streaming state
last_output_offset: int = 0
stream_finished: bool = False
# Output accumulation
text: str = ""
# Token accumulation (for non-streaming)
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
......@@ -139,8 +139,6 @@ class GrpcRequestManager:
self.is_pause_cond = asyncio.Condition()
# Metrics
self.request_counter = 0
self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time()
# Crash dump for debugging
......@@ -158,22 +156,133 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> asyncio.Queue:
):
"""
Submit a generation request to the scheduler.
Returns a queue for streaming outputs.
Submit a generation request to the scheduler with n>1 parallel sampling support.
This method implements the same two-phase approach as tokenizer_manager.py:
1. Phase 1: Send prefix caching request (max_new_tokens=0)
2. Phase 2: Send n generation requests that reuse the cached prefix
Yields individual responses for streaming, or aggregated responses for non-streaming.
"""
n = getattr(obj.sampling_params, "n", 1)
if n <= 1:
async for response in self._handle_single_request(
obj, request_id, grpc_context
):
yield response
return
# N>1 handling - two-phase approach
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
# Generate base request ID if not provided
if request_id is None:
base_request_id = f"grpc-{uuid.uuid4().hex}"
else:
base_request_id = request_id
# Phase 1: Cache the common prefix
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
prefix_obj = copy.copy(obj)
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
# Send prefix caching request and consume response
async for _ in self._handle_single_request(
prefix_obj, f"{base_request_id}-prefix", grpc_context
):
# Consume prefix response (usually just one chunk with finish_reason)
pass
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
# Phase 2: Generate n parallel requests
logger.debug(f"Phase 2: Generating {n} parallel requests")
generators = []
request_ids = []
for i in range(n):
# Create individual generation request
gen_obj = copy.copy(obj)
gen_obj.sampling_params = copy.copy(obj.sampling_params)
gen_obj.sampling_params.n = 1 # Each request generates 1 response
gen_request_id = f"{base_request_id}-{i}"
request_ids.append(gen_request_id)
# Start generation request
generators.append(
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
)
# Handle response aggregation
is_stream = getattr(obj, "stream", False)
if not is_stream:
# Non-streaming: collect all responses and return as batch
logger.debug(f"Non-streaming mode: collecting {n} responses")
responses = []
for generator in generators:
async for response in generator:
responses.append(response)
yield responses # Return all responses as a batch
else:
# Streaming mode: multiplex responses with index for ordering
logger.debug(f"Streaming mode: multiplexing {n} streams")
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
# Create async tasks for all generators
task_map = {}
for generator in generators:
task = asyncio.create_task(generator.__anext__())
task_map[task] = generator
# Process responses as they arrive
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
generator = task_map.pop(task)
try:
response = await task
# Add index for client-side ordering
if isinstance(response, dict) and "meta_info" in response:
response_rid = response["meta_info"].get("id", "")
if response_rid in rid_to_index:
response["index"] = rid_to_index[response_rid]
yield response
# Create next task for this generator
next_task = asyncio.create_task(generator.__anext__())
task_map[next_task] = generator
except StopAsyncIteration:
# This generator is finished
pass
async def _handle_single_request(
self,
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
"""Handle a single request - core implementation without n>1 logic."""
# Generate request ID if not provided
if request_id is None:
async with self.request_counter_lock:
request_id = f"grpc-{self.request_counter}"
self.request_counter += 1
request_id = f"grpc-{uuid.uuid4().hex}"
obj.rid = request_id
# Create and register request state
# TODO: support log_request
# Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=grpc_context,
......@@ -189,19 +298,51 @@ class GrpcRequestManager:
state.session_id = obj.session_params.session_id
state.is_session_request = True
# Register state
self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj)
# Send to scheduler via ZMQ
try:
# Send to scheduler - let exceptions bubble up to grpc_server.py
await self._send_to_scheduler(obj)
except Exception as e:
# Clean up on failure
del self.rid_to_state[request_id]
raise RuntimeError(f"Failed to send request to scheduler: {e}")
return state.out_queue
is_stream = getattr(obj, "stream", False)
while True:
# Client cancelled - notify scheduler and exit
if grpc_context and grpc_context.cancelled():
await self.abort_request(request_id)
return
try:
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
if is_stream:
yield response
# Non-streaming: yield final response with accumulated tokens from state
if isinstance(response, dict) and response.get("finished", False):
if not is_stream:
final_response = response.copy()
final_response["token_ids"] = state.output_ids
yield final_response
break
except asyncio.TimeoutError:
# Timeout waiting for response - abort and cleanup
logger.warning(
f"Timeout waiting for response for request {request_id}"
)
await self.abort_request(request_id)
return
finally:
# Always clean up request state when exiting
self._cleanup_request_state(request_id)
def _cleanup_request_state(self, request_id: str):
"""Clean up local request state (does not notify scheduler)."""
if request_id in self.rid_to_state:
del self.rid_to_state[request_id]
async def embedding_request(
self,
......@@ -214,9 +355,7 @@ class GrpcRequestManager:
"""
# Generate request ID if not provided
if request_id is None:
async with self.request_counter_lock:
request_id = f"grpc-embed-{self.request_counter}"
self.request_counter += 1
request_id = f"grpc-embed-{uuid.uuid4().hex}"
obj.rid = request_id
......@@ -355,7 +494,6 @@ class GrpcRequestManager:
# Extract output for this request
output_data = {
"request_id": rid,
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None,
"meta_info": {
......@@ -367,6 +505,9 @@ class GrpcRequestManager:
if batch_out.completion_tokens
else 0
),
"cached_tokens": (
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i]
......@@ -389,15 +530,10 @@ class GrpcRequestManager:
),
}
# Update state
if output_data["text"]:
state.text += output_data["text"][state.last_output_offset :]
state.last_output_offset = len(output_data["text"])
# Update state for accumulation
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
# Send to output queue
await state.out_queue.put(output_data)
# Handle completion
......
......@@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request)
# Submit to request manager
output_queue = await self.request_manager.generate_request(
# Submit to request manager (automatically handles n>1)
response_generator = self.request_manager.generate_request(
obj=tokenized_req,
request_id=request.request_id,
grpc_context=context,
)
# Stream outputs
while True:
try:
# Get output with timeout
output = await asyncio.wait_for(output_queue.get(), timeout=4)
# Check for errors
async for output in response_generator:
# Handle batch responses (for n>1 non-streaming)
if isinstance(output, list):
for batch_output in output:
if "error" in batch_output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.GenerateError(
message=batch_output["error"],
http_status_code=(
"500" if "abort" not in batch_output else "499"
),
),
)
else:
# All non-error batch outputs are final responses
yield self._create_completion_response(
request.request_id, batch_output
)
else:
# Handle single response (for streaming or n=1 non-streaming)
if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
......@@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
),
),
)
break
# Check if finished
if output.get("finished", False):
# Send completion
elif output.get("finished", False):
yield self._create_completion_response(
request.request_id, output
)
break
else:
# Send chunk
yield self._create_chunk_response(request.request_id, output)
except asyncio.TimeoutError:
# Check if context is still active
if context.cancelled():
# Abort the request
await self.request_manager.abort_request(request.request_id)
break
continue
except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse(
......@@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=True, # Always stream for gRPC
stream=grpc_req.stream or False,
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
......@@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
token_ids=output.get("token_ids", []),
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get("completion_tokens", 0),
cached_tokens=0,
cached_tokens=meta_info.get("cached_tokens", 0),
),
)
......
......@@ -122,6 +122,9 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
}
message TokenizedInput {
......@@ -163,8 +166,8 @@ message GenerateResponse {
}
message GenerateStreamChunk {
// Generated token
int32 token_id = 1;
// Generated tokens (incremental chunk)
repeated int32 token_ids = 1;
// Cumulative counts
int32 prompt_tokens = 2;
......
......@@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
......@@ -101,6 +101,7 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
......@@ -118,7 +119,8 @@ class GenerateRequest(_message.Message):
lora_id: str
data_parallel_rank: int
dp_balance_id: int
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
......@@ -161,20 +163,20 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
token_id: int
token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int
completion_tokens: int
cached_tokens: int
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
......
......@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
logprob_start_len: -1,
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
return_hidden_states: body.return_hidden_states,
stream: body.stream,
..Default::default()
};
......@@ -367,14 +368,14 @@ mod tests {
#[test]
fn test_generate_stream_chunk() {
let chunk = proto::GenerateStreamChunk {
token_id: 1234,
token_ids: vec![1234, 5678],
prompt_tokens: 5,
completion_tokens: 2,
cached_tokens: 3,
..Default::default()
};
assert_eq!(chunk.token_id, 1234);
assert_eq!(chunk.token_ids, vec![1234, 5678]);
assert_eq!(chunk.prompt_tokens, 5);
assert_eq!(chunk.completion_tokens, 2);
assert_eq!(chunk.cached_tokens, 3);
......
......@@ -122,6 +122,9 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
}
message TokenizedInput {
......@@ -163,8 +166,8 @@ message GenerateResponse {
}
message GenerateStreamChunk {
// Generated token
int32 token_id = 1;
// Generated tokens (incremental chunk)
repeated int32 token_ids = 1;
// Cumulative counts
int32 prompt_tokens = 2;
......
......@@ -203,6 +203,7 @@ impl GrpcRouter {
debug!("Selected worker: {}", worker.url());
// Step 2: Get gRPC client for worker (fail fast if can't connect)
// TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
let client = match self.get_or_create_grpc_client(worker.url()).await {
Ok(c) => c,
Err(e) => {
......@@ -249,7 +250,7 @@ impl GrpcRouter {
// Step 6: Build the base gRPC request
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let base_request = match client.build_generate_request(
let request = match client.build_generate_request(
request_id,
body,
processed_messages.text.clone(),
......@@ -268,11 +269,11 @@ impl GrpcRouter {
}
};
// Step 7: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(client, base_request, body).await
self.handle_streaming_chat(client, request, body).await
} else {
self.handle_non_streaming_chat(client, base_request, body)
.await
self.handle_non_streaming_chat(client, request, body).await
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment