Unverified Commit 37158f20 authored by Chang Su, committed by GitHub

router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)

parent 3e95aa1a
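
For orientation before the diff: the change implements n>1 parallel sampling as a two-phase fan-out. One prefill-only request (max_new_tokens=0) warms the shared prefix cache, then n single-sample requests reuse that prefix. Below is a minimal, self-contained sketch of that control flow; the `submit` coroutine and the request/sampling-param shapes are hypothetical stand-ins, not the actual scheduler API.

```python
import asyncio
import copy
import uuid
from dataclasses import dataclass, field


@dataclass
class SamplingParams:
    n: int = 1
    max_new_tokens: int = 128


@dataclass
class GenerateReq:
    rid: str = ""
    sampling_params: SamplingParams = field(default_factory=SamplingParams)


async def submit(req: GenerateReq) -> str:
    """Hypothetical stand-in for sending one request to the scheduler."""
    await asyncio.sleep(0.01)
    return f"done:{req.rid} (max_new_tokens={req.sampling_params.max_new_tokens})"


async def generate_with_parallel_sampling(obj: GenerateReq):
    n = obj.sampling_params.n
    base_rid = f"grpc-{uuid.uuid4().hex}"

    # Phase 1: a prefill-only request caches the shared prefix.
    prefix = copy.copy(obj)
    prefix.sampling_params = copy.copy(obj.sampling_params)
    prefix.sampling_params.max_new_tokens = 0
    prefix.sampling_params.n = 1
    prefix.rid = f"{base_rid}-prefix"
    await submit(prefix)

    # Phase 2: n single-sample requests reuse the cached prefix.
    gens = []
    for i in range(n):
        g = copy.copy(obj)
        g.sampling_params = copy.copy(obj.sampling_params)
        g.sampling_params.n = 1
        g.rid = f"{base_rid}-{i}"
        gens.append(submit(g))
    return await asyncio.gather(*gens)


if __name__ == "__main__":
    req = GenerateReq(sampling_params=SamplingParams(n=3, max_new_tokens=16))
    print(asyncio.run(generate_with_parallel_sampling(req)))
```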
@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
 """
 import asyncio
+import copy
 import dataclasses
 import logging
 import os
@@ -11,6 +12,7 @@ import signal
 import sys
 import threading
 import time
+import uuid
 from typing import Any, Dict, List, Optional, Union
 import grpc
@@ -79,11 +81,9 @@ class GrpcReqState:
     last_completion_tokens: int = 1

     # Streaming state
-    last_output_offset: int = 0
     stream_finished: bool = False

-    # Output accumulation
-    text: str = ""
+    # Token accumulation (for non-streaming)
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -139,8 +139,6 @@ class GrpcRequestManager:
         self.is_pause_cond = asyncio.Condition()

         # Metrics
-        self.request_counter = 0
-        self.request_counter_lock = asyncio.Lock()
         self.last_receive_tstamp = time.time()

         # Crash dump for debugging
@@ -158,22 +156,133 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ) -> asyncio.Queue:
+    ):
         """
-        Submit a generation request to the scheduler.
-        Returns a queue for streaming outputs.
+        Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+        This method implements the same two-phase approach as tokenizer_manager.py:
+        1. Phase 1: Send prefix caching request (max_new_tokens=0)
+        2. Phase 2: Send n generation requests that reuse the cached prefix
+
+        Yields individual responses for streaming, or aggregated responses for non-streaming.
         """
+        n = getattr(obj.sampling_params, "n", 1)
+
+        if n <= 1:
+            async for response in self._handle_single_request(
+                obj, request_id, grpc_context
+            ):
+                yield response
+            return
+
+        # N>1 handling - two-phase approach
+        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+        # Generate base request ID if not provided
+        if request_id is None:
+            base_request_id = f"grpc-{uuid.uuid4().hex}"
+        else:
+            base_request_id = request_id
+
+        # Phase 1: Cache the common prefix
+        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+        prefix_obj = copy.copy(obj)
+        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+        # Send prefix caching request and consume response
+        async for _ in self._handle_single_request(
+            prefix_obj, f"{base_request_id}-prefix", grpc_context
+        ):
+            # Consume prefix response (usually just one chunk with finish_reason)
+            pass
+
+        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+        # Phase 2: Generate n parallel requests
+        logger.debug(f"Phase 2: Generating {n} parallel requests")
+        generators = []
+        request_ids = []
+
+        for i in range(n):
+            # Create individual generation request
+            gen_obj = copy.copy(obj)
+            gen_obj.sampling_params = copy.copy(obj.sampling_params)
+            gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+            gen_request_id = f"{base_request_id}-{i}"
+            request_ids.append(gen_request_id)
+
+            # Start generation request
+            generators.append(
+                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+            )
+
+        # Handle response aggregation
+        is_stream = getattr(obj, "stream", False)
+
+        if not is_stream:
+            # Non-streaming: collect all responses and return as batch
+            logger.debug(f"Non-streaming mode: collecting {n} responses")
+            responses = []
+            for generator in generators:
+                async for response in generator:
+                    responses.append(response)
+            yield responses  # Return all responses as a batch
+        else:
+            # Streaming mode: multiplex responses with index for ordering
+            logger.debug(f"Streaming mode: multiplexing {n} streams")
+            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+            # Create async tasks for all generators
+            task_map = {}
+            for generator in generators:
+                task = asyncio.create_task(generator.__anext__())
+                task_map[task] = generator
+
+            # Process responses as they arrive
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+
+                for task in done:
+                    generator = task_map.pop(task)
+                    try:
+                        response = await task
+
+                        # Add index for client-side ordering
+                        if isinstance(response, dict) and "meta_info" in response:
+                            response_rid = response["meta_info"].get("id", "")
+                            if response_rid in rid_to_index:
+                                response["index"] = rid_to_index[response_rid]
+
+                        yield response
+
+                        # Create next task for this generator
+                        next_task = asyncio.create_task(generator.__anext__())
+                        task_map[next_task] = generator
+                    except StopAsyncIteration:
+                        # This generator is finished
+                        pass
+
+    async def _handle_single_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ):
+        """Handle a single request - core implementation without n>1 logic."""
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-{uuid.uuid4().hex}"
         obj.rid = request_id

-        # Create and register request state
         # TODO: support log_request
+        # Create request state
         state = GrpcReqState(
             request_id=request_id,
             grpc_context=grpc_context,
@@ -189,19 +298,51 @@ class GrpcRequestManager:
             state.session_id = obj.session_params.session_id
             state.is_session_request = True

-        # Register state
         self.rid_to_state[request_id] = state
         self.record_request_for_crash_dump(obj)

-        # Send to scheduler via ZMQ
         try:
+            # Send to scheduler - let exceptions bubble up to grpc_server.py
             await self._send_to_scheduler(obj)
-        except Exception as e:
-            # Clean up on failure
-            del self.rid_to_state[request_id]
-            raise RuntimeError(f"Failed to send request to scheduler: {e}")

-        return state.out_queue
+            is_stream = getattr(obj, "stream", False)
+
+            while True:
+                # Client cancelled - notify scheduler and exit
+                if grpc_context and grpc_context.cancelled():
+                    await self.abort_request(request_id)
+                    return
+
+                try:
+                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                    if is_stream:
+                        yield response
+
+                    # Non-streaming: yield final response with accumulated tokens from state
+                    if isinstance(response, dict) and response.get("finished", False):
+                        if not is_stream:
+                            final_response = response.copy()
+                            final_response["token_ids"] = state.output_ids
+                            yield final_response
+                        break
+
+                except asyncio.TimeoutError:
+                    # Timeout waiting for response - abort and cleanup
+                    logger.warning(
+                        f"Timeout waiting for response for request {request_id}"
+                    )
+                    await self.abort_request(request_id)
+                    return
+
+        finally:
+            # Always clean up request state when exiting
+            self._cleanup_request_state(request_id)
+
+    def _cleanup_request_state(self, request_id: str):
+        """Clean up local request state (does not notify scheduler)."""
+        if request_id in self.rid_to_state:
+            del self.rid_to_state[request_id]

     async def embedding_request(
         self,
@@ -214,9 +355,7 @@ class GrpcRequestManager:
         """
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-embed-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-embed-{uuid.uuid4().hex}"

         obj.rid = request_id
@@ -355,7 +494,6 @@ class GrpcRequestManager:
             # Extract output for this request
             output_data = {
                 "request_id": rid,
-                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                 "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                 "finished": batch_out.finished_reasons[i] is not None,
                 "meta_info": {
@@ -367,6 +505,9 @@ class GrpcRequestManager:
                         if batch_out.completion_tokens
                         else 0
                     ),
+                    "cached_tokens": (
+                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
@@ -389,15 +530,10 @@ class GrpcRequestManager:
                 ),
             }

-            # Update state
-            if output_data["text"]:
-                state.text += output_data["text"][state.last_output_offset :]
-                state.last_output_offset = len(output_data["text"])
+            # Update state for accumulation
             if output_data["token_ids"]:
                 state.output_ids.extend(output_data["token_ids"])

-            # Send to output queue
             await state.out_queue.put(output_data)

             # Handle completion
...
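
The streaming n>1 branch above fans in several per-request async generators by racing their `__anext__()` calls with `asyncio.wait`. The following self-contained illustration of that pattern uses toy generators rather than the request manager's real types:

```python
import asyncio
from typing import AsyncIterator


async def numbered(tag: str, count: int, delay: float) -> AsyncIterator[str]:
    """A toy async generator standing in for one per-request response stream."""
    for i in range(count):
        await asyncio.sleep(delay)
        yield f"{tag}:{i}"


async def multiplex(generators):
    """Yield items from several async generators as soon as any produces one."""
    task_map = {asyncio.ensure_future(g.__anext__()): g for g in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                item = task.result()
            except StopAsyncIteration:
                continue  # this stream is exhausted
            yield item
            # Re-arm the generator so its next item can race with the others.
            task_map[asyncio.ensure_future(gen.__anext__())] = gen


async def main():
    streams = [numbered("a", 3, 0.03), numbered("b", 3, 0.05)]
    async for item in multiplex(streams):
        print(item)


asyncio.run(main())
```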
@@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)

-            # Submit to request manager
-            output_queue = await self.request_manager.generate_request(
+            # Submit to request manager (automatically handles n>1)
+            response_generator = self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )

-            # Stream outputs
-            while True:
-                try:
-                    # Get output with timeout
-                    output = await asyncio.wait_for(output_queue.get(), timeout=4)
-
-                    # Check for errors
+            async for output in response_generator:
+                # Handle batch responses (for n>1 non-streaming)
+                if isinstance(output, list):
+                    for batch_output in output:
+                        if "error" in batch_output:
+                            yield sglang_scheduler_pb2.GenerateResponse(
+                                request_id=request.request_id,
+                                error=sglang_scheduler_pb2.GenerateError(
+                                    message=batch_output["error"],
+                                    http_status_code=(
+                                        "500" if "abort" not in batch_output else "499"
+                                    ),
+                                ),
+                            )
+                        else:
+                            # All non-error batch outputs are final responses
+                            yield self._create_completion_response(
+                                request.request_id, batch_output
+                            )
+                else:
+                    # Handle single response (for streaming or n=1 non-streaming)
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,
@@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                                 ),
                             ),
                         )
-                        break
-
-                    # Check if finished
-                    if output.get("finished", False):
-                        # Send completion
+                    elif output.get("finished", False):
                         yield self._create_completion_response(
                             request.request_id, output
                         )
-                        break
                     else:
-                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)

-                except asyncio.TimeoutError:
-                    # Check if context is still active
-                    if context.cancelled():
-                        # Abort the request
-                        await self.request_manager.abort_request(request.request_id)
-                        break
-                    continue
-
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(
@@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             return_logprob=grpc_req.return_logprob,
             logprob_start_len=grpc_req.logprob_start_len or -1,
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=True,  # Always stream for gRPC
+            stream=grpc_req.stream or False,
             lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
@@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+                token_ids=output.get("token_ids", []),
                 prompt_tokens=meta_info.get("prompt_tokens", 0),
                 completion_tokens=meta_info.get("completion_tokens", 0),
-                cached_tokens=0,
+                cached_tokens=meta_info.get("cached_tokens", 0),
             ),
         )
...
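
On the consuming side, a gRPC client iterates the server stream and branches on which response variant is set. This is only a rough sketch: the import path, the `SglangSchedulerStub` class name, and the `Generate` RPC name follow standard protoc/grpc codegen conventions and are assumptions to verify against the actual generated modules.

```python
import grpc

# Assumed generated modules / names; confirm against the real proto package layout.
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc


async def consume(address: str, request: "sglang_scheduler_pb2.GenerateRequest"):
    async with grpc.aio.insecure_channel(address) as channel:
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
        async for response in stub.Generate(request):
            if response.HasField("error"):
                raise RuntimeError(response.error.message)
            if response.HasField("complete"):
                return list(response.complete.output_ids)
            # Otherwise it is an incremental chunk; token_ids is now a repeated field.
            print("chunk tokens:", list(response.chunk.token_ids))
```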
@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {
@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
...
@@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message):
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -101,6 +101,7 @@ class GenerateRequest(_message.Message):
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
     DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
+    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs
@@ -118,7 +119,8 @@ class GenerateRequest(_message.Message):
     lora_id: str
     data_parallel_rank: int
     dp_balance_id: int
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
+    stream: bool
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...

 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")
@@ -161,20 +163,20 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
-    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    token_id: int
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     logprobs: LogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-    def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
     __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
...
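
The regenerated stubs above make the schema change concrete: `GenerateStreamChunk.token_ids` is now a repeated field and `GenerateRequest` carries a `stream` flag. A small sketch using the generated classes follows; the import path is an assumption and should point at wherever the generated `sglang_scheduler_pb2` module lives.

```python
# Import path is an assumption; adjust it to the generated sglang_scheduler_pb2 module.
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

# A chunk now carries a batch of incremental token ids instead of a single token_id.
chunk = pb2.GenerateStreamChunk(
    token_ids=[1234, 5678],
    prompt_tokens=5,
    completion_tokens=2,
    cached_tokens=3,
)
assert list(chunk.token_ids) == [1234, 5678]

# Requests can now state whether the client wants a streaming response.
req = pb2.GenerateRequest(request_id="chatcmpl-example", stream=True)
assert req.stream
```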
@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
            logprob_start_len: -1,
            top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
            return_hidden_states: body.return_hidden_states,
+           stream: body.stream,
            ..Default::default()
        };
@@ -367,14 +368,14 @@ mod tests {
    #[test]
    fn test_generate_stream_chunk() {
        let chunk = proto::GenerateStreamChunk {
-           token_id: 1234,
+           token_ids: vec![1234, 5678],
            prompt_tokens: 5,
            completion_tokens: 2,
            cached_tokens: 3,
            ..Default::default()
        };

-       assert_eq!(chunk.token_id, 1234);
+       assert_eq!(chunk.token_ids, vec![1234, 5678]);
        assert_eq!(chunk.prompt_tokens, 5);
        assert_eq!(chunk.completion_tokens, 2);
        assert_eq!(chunk.cached_tokens, 3);
...
@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {
@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
...
@@ -203,6 +203,7 @@ impl GrpcRouter {
        debug!("Selected worker: {}", worker.url());

        // Step 2: Get gRPC client for worker (fail fast if can't connect)
+       // TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
        let client = match self.get_or_create_grpc_client(worker.url()).await {
            Ok(c) => c,
            Err(e) => {
@@ -249,7 +250,7 @@ impl GrpcRouter {
        // Step 6: Build the base gRPC request
        let request_id = format!("chatcmpl-{}", Uuid::new_v4());

-       let base_request = match client.build_generate_request(
+       let request = match client.build_generate_request(
            request_id,
            body,
            processed_messages.text.clone(),
@@ -268,11 +269,11 @@ impl GrpcRouter {
            }
        };

-       // Step 7: Handle streaming vs non-streaming
        if body.stream {
-           self.handle_streaming_chat(client, base_request, body).await
+           self.handle_streaming_chat(client, request, body).await
        } else {
-           self.handle_non_streaming_chat(client, base_request, body)
-               .await
+           self.handle_non_streaming_chat(client, request, body).await
        }
    }
...