Unverified commit 37158f20, authored by Chang Su, committed by GitHub

router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)

parent 3e95aa1a
@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
 """

 import asyncio
+import copy
 import dataclasses
 import logging
 import os
@@ -11,6 +12,7 @@ import signal
 import sys
 import threading
 import time
+import uuid
 from typing import Any, Dict, List, Optional, Union

 import grpc
@@ -79,11 +81,9 @@ class GrpcReqState:
     last_completion_tokens: int = 1

     # Streaming state
-    last_output_offset: int = 0
     stream_finished: bool = False

-    # Output accumulation
-    text: str = ""
+    # Token accumulation (for non-streaming)
     output_ids: List[int] = dataclasses.field(default_factory=list)

     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -139,8 +139,6 @@ class GrpcRequestManager:
         self.is_pause_cond = asyncio.Condition()

         # Metrics
-        self.request_counter = 0
-        self.request_counter_lock = asyncio.Lock()
         self.last_receive_tstamp = time.time()

         # Crash dump for debugging
@@ -158,22 +156,133 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ) -> asyncio.Queue:
+    ):
         """
-        Submit a generation request to the scheduler.
-        Returns a queue for streaming outputs.
+        Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+        This method implements the same two-phase approach as tokenizer_manager.py:
+        1. Phase 1: Send prefix caching request (max_new_tokens=0)
+        2. Phase 2: Send n generation requests that reuse the cached prefix
+
+        Yields individual responses for streaming, or aggregated responses for non-streaming.
         """
+        n = getattr(obj.sampling_params, "n", 1)
+
+        if n <= 1:
+            async for response in self._handle_single_request(
+                obj, request_id, grpc_context
+            ):
+                yield response
+            return
+
+        # N>1 handling - two-phase approach
+        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+        # Generate base request ID if not provided
+        if request_id is None:
+            base_request_id = f"grpc-{uuid.uuid4().hex}"
+        else:
+            base_request_id = request_id
+
+        # Phase 1: Cache the common prefix
+        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+        prefix_obj = copy.copy(obj)
+        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
+
+        # Send prefix caching request and consume response
+        async for _ in self._handle_single_request(
+            prefix_obj, f"{base_request_id}-prefix", grpc_context
+        ):
+            # Consume prefix response (usually just one chunk with finish_reason)
+            pass
+
+        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+        # Phase 2: Generate n parallel requests
+        logger.debug(f"Phase 2: Generating {n} parallel requests")
+        generators = []
+        request_ids = []
+        for i in range(n):
+            # Create individual generation request
+            gen_obj = copy.copy(obj)
+            gen_obj.sampling_params = copy.copy(obj.sampling_params)
+            gen_obj.sampling_params.n = 1  # Each request generates 1 response
+            gen_request_id = f"{base_request_id}-{i}"
+            request_ids.append(gen_request_id)
+
+            # Start generation request
+            generators.append(
+                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+            )
+
+        # Handle response aggregation
+        is_stream = getattr(obj, "stream", False)
+
+        if not is_stream:
+            # Non-streaming: collect all responses and return as batch
+            logger.debug(f"Non-streaming mode: collecting {n} responses")
+            responses = []
+            for generator in generators:
+                async for response in generator:
+                    responses.append(response)
+            yield responses  # Return all responses as a batch
+        else:
+            # Streaming mode: multiplex responses with index for ordering
+            logger.debug(f"Streaming mode: multiplexing {n} streams")
+            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+            # Create async tasks for all generators
+            task_map = {}
+            for generator in generators:
+                task = asyncio.create_task(generator.__anext__())
+                task_map[task] = generator
+
+            # Process responses as they arrive
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    generator = task_map.pop(task)
+                    try:
+                        response = await task
+
+                        # Add index for client-side ordering
+                        if isinstance(response, dict) and "meta_info" in response:
+                            response_rid = response["meta_info"].get("id", "")
+                            if response_rid in rid_to_index:
+                                response["index"] = rid_to_index[response_rid]
+
+                        yield response
+
+                        # Create next task for this generator
+                        next_task = asyncio.create_task(generator.__anext__())
+                        task_map[next_task] = generator
+                    except StopAsyncIteration:
+                        # This generator is finished
+                        pass
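Note: the streaming branch above fans in the n per-sample streams by keeping one pending `__anext__()` task per generator and re-arming it after every yield. A minimal, self-contained sketch of that fan-in pattern (`fake_stream` and `multiplex` are illustrative stand-ins, not names from this PR):

```python
import asyncio
from typing import AsyncGenerator, Dict


async def fake_stream(rid: str, n_chunks: int) -> AsyncGenerator[dict, None]:
    # Stand-in for _handle_single_request: a few chunks, then a final one.
    for i in range(n_chunks):
        await asyncio.sleep(0.01)
        yield {"meta_info": {"id": rid}, "chunk": i, "finished": i == n_chunks - 1}


async def multiplex(streams: Dict[str, AsyncGenerator[dict, None]]):
    rid_to_index = {rid: i for i, rid in enumerate(streams)}
    # One pending __anext__() task per generator.
    task_map = {asyncio.ensure_future(g.__anext__()): g for g in streams.values()}
    while task_map:
        done, _ = await asyncio.wait(task_map, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                response = task.result()
            except StopAsyncIteration:
                continue  # this sample's stream is exhausted
            # Tag with a stable index so the client can order interleaved chunks.
            response["index"] = rid_to_index[response["meta_info"]["id"]]
            yield response
            # Re-arm: request the next item from the same generator.
            task_map[asyncio.ensure_future(gen.__anext__())] = gen


async def main() -> None:
    gens = {f"req-{i}": fake_stream(f"req-{i}", 3) for i in range(2)}
    async for r in multiplex(gens):
        print(r)


asyncio.run(main())
```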
+    async def _handle_single_request(
+        self,
+        obj: TokenizedGenerateReqInput,
+        request_id: Optional[str] = None,
+        grpc_context: Optional[grpc.aio.ServicerContext] = None,
+    ):
+        """Handle a single request - core implementation without n>1 logic."""
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-{uuid.uuid4().hex}"
         obj.rid = request_id

-        # Create and register request state
         # TODO: support log_request
+
+        # Create request state
         state = GrpcReqState(
             request_id=request_id,
             grpc_context=grpc_context,
@@ -189,19 +298,51 @@ class GrpcRequestManager:
             state.session_id = obj.session_params.session_id
             state.is_session_request = True

+        # Register state
         self.rid_to_state[request_id] = state
         self.record_request_for_crash_dump(obj)

-        # Send to scheduler via ZMQ
         try:
+            # Send to scheduler - let exceptions bubble up to grpc_server.py
             await self._send_to_scheduler(obj)
-        except Exception as e:
-            # Clean up on failure
-            del self.rid_to_state[request_id]
-            raise RuntimeError(f"Failed to send request to scheduler: {e}")

-        return state.out_queue
+            is_stream = getattr(obj, "stream", False)
+
+            while True:
+                # Client cancelled - notify scheduler and exit
+                if grpc_context and grpc_context.cancelled():
+                    await self.abort_request(request_id)
+                    return
+
+                try:
+                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                    if is_stream:
+                        yield response
+
+                    # Non-streaming: yield final response with accumulated tokens from state
+                    if isinstance(response, dict) and response.get("finished", False):
+                        if not is_stream:
+                            final_response = response.copy()
+                            final_response["token_ids"] = state.output_ids
+                            yield final_response
+                        break
+                except asyncio.TimeoutError:
+                    # Timeout waiting for response - abort and cleanup
+                    logger.warning(
+                        f"Timeout waiting for response for request {request_id}"
+                    )
+                    await self.abort_request(request_id)
+                    return
+        finally:
+            # Always clean up request state when exiting
+            self._cleanup_request_state(request_id)
+
+    def _cleanup_request_state(self, request_id: str):
+        """Clean up local request state (does not notify scheduler)."""
+        if request_id in self.rid_to_state:
+            del self.rid_to_state[request_id]
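Note: callers of `generate_request` now see three response shapes: per-chunk dicts while streaming, one final dict for n=1 non-streaming, and a list of final dicts for n>1 non-streaming. A sketch of a consumer that normalizes all three (the helper name and typing alias are mine, not from this PR):

```python
from typing import Any, AsyncGenerator, Dict, List, Union

Response = Dict[str, Any]


async def collect_finals(
    responses: AsyncGenerator[Union[Response, List[Response]], None],
) -> List[Response]:
    """Gather final responses from GrpcRequestManager.generate_request()."""
    finals: List[Response] = []
    async for output in responses:
        if isinstance(output, list):
            # n>1 non-streaming: one aggregated batch of final responses
            finals.extend(output)
        elif output.get("finished", False):
            # n=1 non-streaming final, or the last chunk of a stream
            finals.append(output)
        # else: an intermediate streamed chunk; a real server forwards it
    return finals
```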
     async def embedding_request(
         self,
@@ -214,9 +355,7 @@ class GrpcRequestManager:
         """
         # Generate request ID if not provided
         if request_id is None:
-            async with self.request_counter_lock:
-                request_id = f"grpc-embed-{self.request_counter}"
-                self.request_counter += 1
+            request_id = f"grpc-embed-{uuid.uuid4().hex}"

         obj.rid = request_id
@@ -355,7 +494,6 @@ class GrpcRequestManager:
             # Extract output for this request
             output_data = {
                 "request_id": rid,
-                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                 "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                 "finished": batch_out.finished_reasons[i] is not None,
                 "meta_info": {
@@ -367,6 +505,9 @@ class GrpcRequestManager:
                         if batch_out.completion_tokens
                         else 0
                     ),
+                    "cached_tokens": (
+                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                    ),
                     "finish_reason": (
                         str(batch_out.finished_reasons[i])
                         if batch_out.finished_reasons[i]
@@ -389,15 +530,10 @@ class GrpcRequestManager:
                     ),
                 }

-            # Update state
-            if output_data["text"]:
-                state.text += output_data["text"][state.last_output_offset :]
-                state.last_output_offset = len(output_data["text"])
+            # Update state for accumulation
             if output_data["token_ids"]:
                 state.output_ids.extend(output_data["token_ids"])

-            # Send to output queue
             await state.out_queue.put(output_data)

             # Handle completion
...
@@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)

-            # Submit to request manager
-            output_queue = await self.request_manager.generate_request(
+            # Submit to request manager (automatically handles n>1)
+            response_generator = self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )

-            # Stream outputs
-            while True:
-                try:
-                    # Get output with timeout
-                    output = await asyncio.wait_for(output_queue.get(), timeout=4)
-
-                    # Check for errors
+            async for output in response_generator:
+                # Handle batch responses (for n>1 non-streaming)
+                if isinstance(output, list):
+                    for batch_output in output:
+                        if "error" in batch_output:
+                            yield sglang_scheduler_pb2.GenerateResponse(
+                                request_id=request.request_id,
+                                error=sglang_scheduler_pb2.GenerateError(
+                                    message=batch_output["error"],
+                                    http_status_code=(
+                                        "500" if "abort" not in batch_output else "499"
+                                    ),
+                                ),
+                            )
+                        else:
+                            # All non-error batch outputs are final responses
+                            yield self._create_completion_response(
+                                request.request_id, batch_output
+                            )
+                else:
+                    # Handle single response (for streaming or n=1 non-streaming)
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,
@@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                             ),
                         )
-                        break
-
-                    # Check if finished
-                    if output.get("finished", False):
-                        # Send completion
+                    elif output.get("finished", False):
                         yield self._create_completion_response(
                             request.request_id, output
                         )
-                        break
                     else:
-                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)

-                except asyncio.TimeoutError:
-                    # Check if context is still active
-                    if context.cancelled():
-                        # Abort the request
-                        await self.request_manager.abort_request(request.request_id)
-                        break
-                    continue
-
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(
@@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             return_logprob=grpc_req.return_logprob,
             logprob_start_len=grpc_req.logprob_start_len or -1,
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=True,  # Always stream for gRPC
+            stream=grpc_req.stream or False,
             lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
@@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+                token_ids=output.get("token_ids", []),
                 prompt_tokens=meta_info.get("prompt_tokens", 0),
                 completion_tokens=meta_info.get("completion_tokens", 0),
-                cached_tokens=0,
+                cached_tokens=meta_info.get("cached_tokens", 0),
             ),
         )
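Note: with `stream` now honored end to end and chunks carrying `repeated int32 token_ids`, a client accumulates token ids with `extend` rather than collecting one `token_id` per message. A hedged sketch of a Python caller, assuming the default generated stub name (`SglangSchedulerStub`), a local server address, and placeholder token ids:

```python
import asyncio

import grpc
import sglang_scheduler_pb2 as pb
import sglang_scheduler_pb2_grpc as pb_grpc


async def main() -> None:
    # localhost:30000 is an assumed address; adjust for your deployment.
    async with grpc.aio.insecure_channel("localhost:30000") as channel:
        stub = pb_grpc.SglangSchedulerStub(channel)
        request = pb.GenerateRequest(
            request_id="demo-1",
            tokenized=pb.TokenizedInput(original_text="Hello", input_ids=[9906]),
            sampling_params=pb.SamplingParams(max_new_tokens=16, n=1),
            stream=True,  # new field 18: ask for incremental chunks
        )
        token_ids = []
        async for response in stub.Generate(request):
            kind = response.WhichOneof("response")
            if kind == "chunk":
                # token_ids is now `repeated int32`: extend, don't append
                token_ids.extend(response.chunk.token_ids)
            elif kind == "complete":
                print("finished:", response.complete.finish_reason)
            elif kind == "error":
                raise RuntimeError(response.error.message)
        print("accumulated token ids:", token_ids)


asyncio.run(main())
```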
...
@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {
@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
...
@@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__
 from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2

-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler…')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler…')

(Both serialized descriptor byte strings are elided here; the regenerated blob differs only in that GenerateRequest gains `bool stream = 18` and GenerateStreamChunk's `int32 token_id = 1` becomes `repeated int32 token_ids = 1`.)

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -45,65 +45,65 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_DISAGGREGATEDPARAMS']._serialized_start=828
   _globals['_DISAGGREGATEDPARAMS']._serialized_end=921
   _globals['_GENERATEREQUEST']._serialized_start=924
-  _globals['_GENERATEREQUEST']._serialized_end=1541
+  _globals['_GENERATEREQUEST']._serialized_end=1557

(The remaining _serialized_start/_serialized_end pairs, from _TOKENIZEDINPUT through _SGLANGSCHEDULER, shift mechanically with the regenerated descriptor: +16 bytes after the new GenerateRequest field, +17 after GenerateStreamChunk. The individual updated lines are elided here.)

 # @@protoc_insertion_point(module_scope)
@@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message):
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -101,6 +101,7 @@ class GenerateRequest(_message.Message):
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
     DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
+    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs
@@ -118,7 +119,8 @@ class GenerateRequest(_message.Message):
     lora_id: str
     data_parallel_rank: int
     dp_balance_id: int
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
+    stream: bool
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...

 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")
@@ -161,20 +163,20 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
-    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    token_id: int
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     logprobs: LogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-    def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
     __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
...
@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
             logprob_start_len: -1,
             top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
             return_hidden_states: body.return_hidden_states,
+            stream: body.stream,
             ..Default::default()
         };
@@ -367,14 +368,14 @@ mod tests {
     #[test]
     fn test_generate_stream_chunk() {
         let chunk = proto::GenerateStreamChunk {
-            token_id: 1234,
+            token_ids: vec![1234, 5678],
             prompt_tokens: 5,
             completion_tokens: 2,
             cached_tokens: 3,
             ..Default::default()
         };

-        assert_eq!(chunk.token_id, 1234);
+        assert_eq!(chunk.token_ids, vec![1234, 5678]);
         assert_eq!(chunk.prompt_tokens, 5);
         assert_eq!(chunk.completion_tokens, 2);
         assert_eq!(chunk.cached_tokens, 3);
...
@@ -122,6 +122,9 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
+
+  // Whether client wants streaming response
+  bool stream = 18;
 }

 message TokenizedInput {
@@ -163,8 +166,8 @@ message GenerateResponse {
 }

 message GenerateStreamChunk {
-  // Generated token
-  int32 token_id = 1;
+  // Generated tokens (incremental chunk)
+  repeated int32 token_ids = 1;

   // Cumulative counts
   int32 prompt_tokens = 2;
...
@@ -203,6 +203,7 @@ impl GrpcRouter {
         debug!("Selected worker: {}", worker.url());

         // Step 2: Get gRPC client for worker (fail fast if can't connect)
+        // TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
         let client = match self.get_or_create_grpc_client(worker.url()).await {
             Ok(c) => c,
             Err(e) => {
@@ -249,7 +250,7 @@ impl GrpcRouter {
         // Step 6: Build the base gRPC request
         let request_id = format!("chatcmpl-{}", Uuid::new_v4());

-        let base_request = match client.build_generate_request(
+        let request = match client.build_generate_request(
             request_id,
             body,
             processed_messages.text.clone(),
@@ -268,11 +269,11 @@ impl GrpcRouter {
             }
         };

+        // Step 7: Handle streaming vs non-streaming
         if body.stream {
-            self.handle_streaming_chat(client, base_request, body).await
+            self.handle_streaming_chat(client, request, body).await
         } else {
-            self.handle_non_streaming_chat(client, base_request, body)
-                .await
+            self.handle_non_streaming_chat(client, request, body).await
         }
     }
...