Commit 53ca1552 (unverified)

Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)

Authored Sep 11, 2025 by Chang Su; committed via GitHub on Sep 11, 2025.
Parent: a23bdeaf

Showing 11 changed files with 2486 additions and 285 deletions (+2486, -285):
    .pre-commit-config.yaml                                   +8    -2
    python/sglang/srt/entrypoints/grpc_request_manager.py     +580  -0
    python/sglang/srt/entrypoints/grpc_server.py              +680  -0
    python/sglang/srt/grpc/__init__.py                        +1    -0
    python/sglang/srt/grpc/sglang_scheduler.proto             +389  -0
    python/sglang/srt/grpc/sglang_scheduler_pb2.py            +106  -0
    python/sglang/srt/grpc/sglang_scheduler_pb2.pyi           +427  -0
    python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py       +236  -0
    python/sglang/srt/server_args.py                          +1    -0
    sgl-router/src/grpc/client.rs                             +36   -109
    sgl-router/src/proto/sglang_scheduler.proto               +22   -174
.pre-commit-config.yaml (modified, +8 -2)

```diff
@@ -22,17 +22,19 @@ repos:
     rev: 5.13.2
     hooks:
       - id: isort
+        exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.11.7
     hooks:
       - id: ruff
         args: [--select=F401, --fixable=F401]
         files: ^(benchmark/|docs/|examples/)
-        exclude: \.ipynb$
+        exclude: \.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
   - repo: https://github.com/psf/black
     rev: 24.10.0
     hooks:
       - id: black-jupyter
+        exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
     hooks:
@@ -42,7 +44,11 @@ repos:
         exclude: |
           (?x)^(
               test/srt/test_reasoning_parser\.py|
-              docs/advanced_features/vlm_query\.ipynb
+              docs/advanced_features/vlm_query\.ipynb|
+              python/sglang/srt/grpc/.*_pb2\.py|
+              python/sglang/srt/grpc/.*_pb2_grpc\.py|
+              python/sglang/srt/grpc/.*_pb2\.pyi|
+              python/sglang/srt/grpc/.*_pb2_grpc\.pyi
           )$
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v18.1.8
```
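The excluded paths are protoc output: `sglang_scheduler_pb2.py`, `sglang_scheduler_pb2.pyi`, and `sglang_scheduler_pb2_grpc.py` are regenerated from `sglang_scheduler.proto` rather than hand-edited, so the formatters and linters skip them. A hedged regeneration sketch (not part of this commit; the include paths and output directories are assumptions based on the repo layout above):

```python
# Regeneration sketch: grpcio-tools bundles protoc and can be driven from
# Python. Run from the repo root; paths are assumed from the file list above.
import sys

import pkg_resources
from grpc_tools import protoc

# Well-known types (timestamp.proto, struct.proto) ship with grpcio-tools.
proto_include = pkg_resources.resource_filename("grpc_tools", "_proto")

sys.exit(
    protoc.main(
        [
            "grpc_tools.protoc",
            f"-I{proto_include}",
            "-Ipython/sglang/srt/grpc",
            "--python_out=python/sglang/srt/grpc",       # -> sglang_scheduler_pb2.py
            "--pyi_out=python/sglang/srt/grpc",          # -> sglang_scheduler_pb2.pyi
            "--grpc_python_out=python/sglang/srt/grpc",  # -> sglang_scheduler_pb2_grpc.py
            "python/sglang/srt/grpc/sglang_scheduler.proto",
        ]
    )
)
```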
python/sglang/srt/entrypoints/grpc_request_manager.py (new file, mode 100644, +580)

```python
"""
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
Mimics TokenizerManager's state management and ZMQ communication patterns.
"""

import asyncio
import dataclasses
import logging
import os
import signal
import sys
import threading
import time
from typing import Any, Dict, List, Optional, Union

import grpc
import zmq
import zmq.asyncio

from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    HealthCheckOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class GrpcSignalHandler:
    """Minimal signal handler for gRPC server - delegates real crash handling to scheduler."""

    def __init__(self, grpc_manager):
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Handle SIGTERM by gracefully shutting down gRPC server."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        )
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Handle SIGQUIT from failed scheduler process."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # Just exit cleanly - the scheduler handles crash dumps
        kill_process_tree(os.getpid(), include_parent=True)


@dataclasses.dataclass
class GrpcReqState:
    """State tracking for a gRPC request."""

    # Request identification
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # Communication
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # Metrics (same as TokenizerManager's ReqState)
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # Streaming state
    last_output_offset: int = 0
    stream_finished: bool = False

    # Output accumulation
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # Session state
    session_id: Optional[str] = None
    is_session_request: bool = False


class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the gRPC request manager."""
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        context = zmq.asyncio.Context(2)

        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )

        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        self.asyncio_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.request_counter = 0
        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> asyncio.Queue:
        """
        Submit a generation request to the scheduler.
        Returns a queue for streaming outputs.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-{self.request_counter}"
                self.request_counter += 1
        obj.rid = request_id

        # TODO: support log_request

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Track session if needed
        if hasattr(obj, "session_params") and obj.session_params:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        # Register state
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Send to scheduler via ZMQ
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            # Clean up on failure
            del self.rid_to_state[request_id]
            raise RuntimeError(f"Failed to send request to scheduler: {e}")

        return state.out_queue

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.
        Returns a future that will contain the embedding result.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-embed-{self.request_counter}"
                self.request_counter += 1
        obj.rid = request_id

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Register state
        self.rid_to_state[request_id] = state

        # Create future for result
        future = asyncio.Future()

        # Send to scheduler
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            del self.rid_to_state[request_id]
            future.set_exception(e)
            return future

        # Wait for result in background
        async def wait_for_result():
            try:
                # Wait for completion
                await state.event.wait()
                # Get result from queue
                result = await state.out_queue.get()
                future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                # Clean up
                if request_id in self.rid_to_state:
                    del self.rid_to_state[request_id]

        asyncio.create_task(wait_for_result())
        return future

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request."""
        if request_id not in self.rid_to_state:
            return False

        # Send abort to scheduler
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        # Mark as finished
        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()
            # Send abort notification to output queue
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing."""
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.
        """
        while not self.gracefully_exit:
            try:
                # Receive from scheduler
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                # Handle different output types
                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler."""
        # Process each request in the batch
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Update metrics
            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            # Extract output for this request
            output_data = {
                "request_id": rid,
                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
                        else None
                    ),
                },
            }

            # Add logprobs if available
            if batch_out.output_token_logprobs_val and i < len(
                batch_out.output_token_logprobs_val
            ):
                output_data["logprobs"] = {
                    "tokens": batch_out.output_token_logprobs_val[i],
                    "top_logprobs": (
                        batch_out.output_top_logprobs_val[i]
                        if batch_out.output_top_logprobs_val
                        and i < len(batch_out.output_top_logprobs_val)
                        else None
                    ),
                }

            # Update state
            if output_data["text"]:
                state.text += output_data["text"][state.last_output_offset :]
                state.last_output_offset = len(output_data["text"])
            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Send to output queue
            await state.out_queue.put(output_data)

            # Handle completion
            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay
                async def cleanup():
                    await asyncio.sleep(5.0)
                    if rid in self.rid_to_state:
                        del self.rid_to_state[rid]

                asyncio.create_task(cleanup())

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler."""
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Create result
            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": (
                    batch_out.finish_reason[i] if batch_out.finish_reason else None
                ),
            }

            # Send result
            await state.out_queue.put(result)

            # Mark as finished
            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler."""
        rid = health_out.rid
        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]

        # Create health check result
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }

        # Send result
        await state.out_queue.put(result)

        # Mark as finished
        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ."""
        try:
            self.send_to_scheduler.send_pyobj(obj)
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def record_request_for_crash_dump(self, obj):
        """Record request for potential crash dump."""
        if len(self.crash_dump_request_list) < 100:
            self.crash_dump_request_list.append(
                {
                    "time": time.time(),
                    "request_id": getattr(obj, "rid", "unknown"),
                    "type": type(obj).__name__,
                }
            )

    async def shutdown(self):
        """Gracefully shutdown the request manager."""
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all pending requests
        for rid, state in self.rid_to_state.items():
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Wait for tasks to complete
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )

        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
        if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
            func.__self__.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
```
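The manager exposes a queue-per-request model: `generate_request` registers a `GrpcReqState`, forwards the tokenized request over ZMQ, and returns that request's `asyncio.Queue`, which `handle_loop` fills as scheduler batches arrive. A minimal consumption sketch, assuming a scheduler process is already serving the ZMQ endpoints named in `port_args` and that `req` is a fully populated `TokenizedGenerateReqInput` (both are placeholders here); the output keys (`text`, `finished`, `error`) follow `_handle_batch_output` above:

```python
# Consumption sketch for GrpcRequestManager (assumes a live scheduler).
import asyncio

from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.server_args import PortArgs, ServerArgs


async def demo(
    server_args: ServerArgs, port_args: PortArgs, req: TokenizedGenerateReqInput
):
    manager = GrpcRequestManager(server_args=server_args, port_args=port_args)
    manager.auto_create_handle_loop()  # starts handle_loop + sigterm watchdog

    # Returns immediately with the per-request queue; handle_loop fills it
    # as BatchTokenIDOut messages come back from the scheduler.
    queue = await manager.generate_request(req)
    while True:
        out = await queue.get()
        if "error" in out:
            raise RuntimeError(out["error"])
        print(out["text"], end="", flush=True)
        if out.get("finished"):
            break

    await manager.shutdown()
```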
python/sglang/srt/entrypoints/grpc_server.py (new file, mode 100644, +680)

```python
"""
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
Uses GrpcRequestManager for orchestration without tokenization.
"""

import argparse
import asyncio
import logging
import multiprocessing as mp
import os
import signal
import time
from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple

import grpc
from grpc_reflection.v1alpha import reflection

from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


def _launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
    """
    Launch only the scheduler process(es) without tokenizer/detokenizer.
    Returns scheduler info, port args, and list of scheduler processes.
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # Prepare model and tokenizer paths
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []
    if server_args.dp_size == 1:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=run_scheduler_process,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,
                        writer,
                        None,
                    ),
                )
                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    # TODO(CatherineSue): handle cases for multi-node

    # Wait for all scheduler processes to be ready
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: "
                f"{data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Return the first scheduler's info (they should all be the same)
    return scheduler_infos[0], port_args, scheduler_procs


class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
    """
    Standalone gRPC service implementation using GrpcRequestManager.
    Fully separated from HTTP server with its own process and no shared globals.
    """

    def __init__(
        self,
        request_manager: GrpcRequestManager,
        server_args: ServerArgs,
        model_info: Dict,
    ):
        """Initialize the standalone gRPC service."""
        self.request_manager = request_manager
        self.server_args = server_args
        self.model_info = model_info
        self.start_time = time.time()

        # Start the request manager's event loop using auto_create_handle_loop
        self.request_manager.auto_create_handle_loop()

        logger.info("Standalone gRPC scheduler service initialized")

    async def Generate(
        self,
        request: sglang_scheduler_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
        """Handle generation requests with streaming responses."""
        logger.info(f"Generation request: {request.request_id}")

        try:
            # Convert gRPC request to internal format
            tokenized_req = self._convert_generate_request(request)

            # Submit to request manager
            output_queue = await self.request_manager.generate_request(
                obj=tokenized_req,
                request_id=request.request_id,
                grpc_context=context,
            )

            # Stream outputs
            while True:
                try:
                    # Get output with timeout
                    output = await asyncio.wait_for(output_queue.get(), timeout=4)

                    # Check for errors
                    if "error" in output:
                        yield sglang_scheduler_pb2.GenerateResponse(
                            request_id=request.request_id,
                            error=sglang_scheduler_pb2.GenerateError(
                                message=output["error"],
                                http_status_code=(
                                    "500" if "abort" not in output else "499"
                                ),
                            ),
                        )
                        break

                    # Check if finished
                    if output.get("finished", False):
                        # Send completion
                        yield self._create_completion_response(
                            request.request_id, output
                        )
                        break
                    else:
                        # Send chunk
                        yield self._create_chunk_response(request.request_id, output)

                except asyncio.TimeoutError:
                    # Check if context is still active
                    if context.cancelled():
                        # Abort the request
                        await self.request_manager.abort_request(request.request_id)
                        break
                    continue

        except Exception as e:
            logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
            yield sglang_scheduler_pb2.GenerateResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.GenerateError(
                    message=str(e),
                    http_status_code="500",
                    details=get_exception_traceback(),
                ),
            )

    async def Embed(
        self,
        request: sglang_scheduler_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.EmbedResponse:
        """Handle embedding requests."""
        logger.info(f"Embedding request: {request.request_id}")

        try:
            # Convert request
            tokenized_req = self._convert_embed_request(request)

            # Submit to request manager
            future = await self.request_manager.embedding_request(
                obj=tokenized_req,
                request_id=request.request_id,
            )

            # Wait for result
            result = await future

            # Create response
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                complete=sglang_scheduler_pb2.EmbedComplete(
                    embedding=result["embedding"],
                    prompt_tokens=result.get("prompt_tokens", 0),
                    cached_tokens=0,
                    embedding_dim=len(result["embedding"]),
                    generation_time=time.time() - self.start_time,
                ),
            )

        except Exception as e:
            logger.error(f"Embed failed: {e}\n{get_exception_traceback()}")
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.EmbedError(
                    message=str(e),
                    code="INTERNAL_ERROR",
                    details=get_exception_traceback(),
                ),
            )

    async def HealthCheck(
        self,
        request: sglang_scheduler_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.HealthCheckResponse:
        """Health check by generating from client input."""
        try:
            # Check if request manager is shutting down
            if self.request_manager.gracefully_exit:
                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Server shutting down"
                )

            # Extract tokenized input from request
            if not request.HasField("tokenized"):
                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Tokenized input required for health check"
                )

            input_text = request.tokenized.original_text
            input_ids = list(request.tokenized.input_ids)

            # Create health check request
            rid = f"HEALTH_CHECK_GRPC_{time.time()}"
            health_request = TokenizedGenerateReqInput(
                rid=rid,
                input_text=input_text,
                input_ids=input_ids,
                sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
                stream=False,
                mm_inputs=None,
                return_logprob=False,
                logprob_start_len=-1,
                top_logprobs_num=0,
                token_ids_logprob=None,
            )

            logger.info(f"Sending health check request to request manager...")

            # Submit and wait for response
            output_queue = await self.request_manager.generate_request(
                health_request, request_id=rid
            )

            try:
                # Wait for response with configurable timeout
                response = await asyncio.wait_for(
                    output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
                )

                # Clean up
                if rid in self.request_manager.rid_to_state:
                    del self.request_manager.rid_to_state[rid]

                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=True, message="Health check passed"
                )
            except asyncio.TimeoutError:
                # Clean up on timeout
                if rid in self.request_manager.rid_to_state:
                    del self.request_manager.rid_to_state[rid]
                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Health check timeout"
                )

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message=f"Health check error: {str(e)}"
            )

    async def Abort(
        self,
        request: sglang_scheduler_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.AbortResponse:
        """Abort an ongoing request."""
        logger.info(f"Aborting request: {request.request_id}")

        try:
            success = await self.request_manager.abort_request(request.request_id)

            return sglang_scheduler_pb2.AbortResponse(
                success=success,
                message=f"Request {request.request_id} "
                f"{'aborted' if success else 'not found'}",
            )
        except Exception as e:
            logger.error(f"Abort failed: {e}")
            return sglang_scheduler_pb2.AbortResponse(
                success=False,
                message=str(e),
            )

    # Helper methods for request/response conversion

    def _convert_generate_request(
        self, grpc_req: sglang_scheduler_pb2.GenerateRequest
    ) -> TokenizedGenerateReqInput:
        """Convert gRPC GenerateRequest to internal format."""
        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        # Convert sampling params
        sampling_params = self._convert_sampling_params(grpc_req.sampling_params)

        # Create request
        return TokenizedGenerateReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
            mm_inputs=None,  # TODO: implement mm support
            sampling_params=sampling_params,
            return_logprob=grpc_req.return_logprob,
            logprob_start_len=grpc_req.logprob_start_len or -1,
            top_logprobs_num=grpc_req.top_logprobs_num or 0,
            stream=True,  # Always stream for gRPC
            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
            token_ids_logprob=(
                list(grpc_req.token_ids_logprob)
                if grpc_req.token_ids_logprob
                else None
            ),
        )

    def _convert_embed_request(
        self, grpc_req: sglang_scheduler_pb2.EmbedRequest
    ) -> TokenizedEmbeddingReqInput:
        """Convert gRPC EmbedRequest to internal format."""
        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        return TokenizedEmbeddingReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
        )

    def _convert_sampling_params(
        self, grpc_params: sglang_scheduler_pb2.SamplingParams
    ) -> SGLSamplingParams:
        """Convert gRPC SamplingParams to internal format."""
        # Handle constraint types
        regex = None
        json_schema = None
        ebnf_grammar = None

        if grpc_params.HasField("regex"):
            regex = grpc_params.regex
        elif grpc_params.HasField("json_schema"):
            json_schema = grpc_params.json_schema
        elif grpc_params.HasField("ebnf_grammar"):
            ebnf_grammar = grpc_params.ebnf_grammar

        return SGLSamplingParams(
            temperature=grpc_params.temperature or 1.0,
            top_p=grpc_params.top_p or 1.0,
            top_k=grpc_params.top_k or -1,
            min_p=grpc_params.min_p or 0.0,
            frequency_penalty=grpc_params.frequency_penalty or 0.0,
            presence_penalty=grpc_params.presence_penalty or 0.0,
            repetition_penalty=grpc_params.repetition_penalty or 1.0,
            max_new_tokens=grpc_params.max_new_tokens or 128,
            min_new_tokens=grpc_params.min_new_tokens or 0,
            stop=list(grpc_params.stop) if grpc_params.stop else None,
            stop_token_ids=(
                list(grpc_params.stop_token_ids)
                if grpc_params.stop_token_ids
                else None
            ),
            skip_special_tokens=grpc_params.skip_special_tokens,
            spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
            regex=regex,
            json_schema=json_schema,
            ebnf=ebnf_grammar,
            n=grpc_params.n or 1,
            ignore_eos=grpc_params.ignore_eos,
        )

    def _create_chunk_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a streaming chunk response."""
        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            chunk=sglang_scheduler_pb2.GenerateStreamChunk(
                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
                text=output.get("text", ""),
                prompt_tokens=0,
                completion_tokens=len(output.get("token_ids", [])),
                cached_tokens=0,
                generation_time=time.time() - self.start_time,
                queue_time=0.0,
            ),
        )

    def _create_completion_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a completion response."""
        # Determine finish reason
        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
        meta_info = output.get("meta_info", {})
        if meta_info.get("finish_reason") == "length":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
        elif meta_info.get("finish_reason") == "eos_token":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            complete=sglang_scheduler_pb2.GenerateComplete(
                output_ids=output.get("token_ids", []),
                output_text=output.get("text", ""),
                finish_reason=finish_reason,
            ),
        )

    async def shutdown(self):
        """Shutdown the service."""
        logger.info("Shutting down gRPC service")

        # Shutdown request manager (handles its own tasks)
        await self.request_manager.shutdown()


async def serve_grpc(
    server_args: ServerArgs,
    model_info: Optional[Dict] = None,
):
    """Start the standalone gRPC server with integrated scheduler."""
    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
        server_args=server_args,
    )

    # Update model info from scheduler info
    if model_info is None:
        model_info = {
            "model_name": server_args.model_path,
            "max_context_length": scheduler_info.get(
                "max_total_num_tokens", server_args.context_length or 8192
            ),
            "vocab_size": scheduler_info.get("vocab_size", 128256),
            "supports_vision": scheduler_info.get("supports_vision", False),
            "model_type": scheduler_info.get("model_type", "transformer"),
            "max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
            "eos_token_ids": scheduler_info.get("eos_token_ids", []),
            "pad_token_id": scheduler_info.get("pad_token_id", 0),
            "bos_token_id": scheduler_info.get("bos_token_id", 1),
        }

    # Create request manager with the correct port args
    request_manager = GrpcRequestManager(
        server_args=server_args,
        port_args=port_args,
    )

    # Create gRPC server
    server = grpc.aio.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ("grpc.max_send_message_length", 1024 * 1024 * 256),
            ("grpc.max_receive_message_length", 1024 * 1024 * 256),
        ],
    )

    # Add service
    servicer = SGLangSchedulerServicer(
        request_manager=request_manager,
        server_args=server_args,
        model_info=model_info,
    )
    sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)

    # Enable reflection
    SERVICE_NAMES = (
        sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(SERVICE_NAMES, server)

    # Start server
    listen_addr = f"{server_args.host}:{server_args.port}"
    server.add_insecure_port(listen_addr)

    logger.info(f"Starting standalone gRPC server on {listen_addr}")
    await server.start()

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler():
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    try:
        await stop_event.wait()
    finally:
        logger.info("Shutting down gRPC server")
        await servicer.shutdown()
        await server.stop(5.0)

        # Terminate scheduler processes
        for i, proc in enumerate(scheduler_procs):
            if proc and proc.is_alive():
                logger.info(f"Terminating scheduler process {i}...")
                proc.terminate()
                proc.join(timeout=5.0)
                if proc.is_alive():
                    logger.warning(f"Force killing scheduler process {i}...")
                    proc.kill()
                    proc.join()


def main():
    """Main entry point for standalone gRPC server."""
    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")

    # Server arguments
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")

    # Model arguments
    parser.add_argument("--model-path", type=str, required=True, help="Model path")
    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
    parser.add_argument("--context-length", type=int, help="Context length")
    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")

    # Runtime arguments
    parser.add_argument(
        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
    )
    parser.add_argument(
        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
    )
    parser.add_argument(
        "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
    )
    parser.add_argument(
        "--attention-backend", type=str, default="flashinfer", help="Attention backend"
    )
    parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")

    # Logging
    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")

    args = parser.parse_args()

    # Convert to ServerArgs with gRPC host and port
    server_args = ServerArgs(
        model_path=args.model_path,
        tokenizer_path=args.tokenizer_path or args.model_path,
        context_length=args.context_length,
        tp_size=args.tp_size,
        dp_size=args.dp_size,
        max_running_requests=args.max_running_requests,
        max_total_tokens=args.max_total_tokens,
        max_prefill_tokens=args.max_prefill_tokens,
        attention_backend=args.attention_backend,
        lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
        log_level=args.log_level,
        # Override with gRPC server host and port
        host=args.host,
        port=args.port,
    )

    # Run server
    asyncio.run(
        serve_grpc(
            server_args=server_args,
        )
    )


if __name__ == "__main__":
    main()
```
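With the server up (e.g. `python -m sglang.srt.entrypoints.grpc_server --model-path <model>`), a streaming client can be sketched against the generated stubs. The token IDs below are hypothetical placeholders: this API deliberately accepts only pre-tokenized input, so a real client runs the tokenizer before calling `Generate`.

```python
# Client sketch for the standalone gRPC server (placeholder token IDs).
import asyncio

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc


async def main():
    async with grpc.aio.insecure_channel("localhost:30000") as channel:
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
        request = sglang_scheduler_pb2.GenerateRequest(
            request_id="demo-1",
            tokenized=sglang_scheduler_pb2.TokenizedInput(
                original_text="Hello", input_ids=[9906]  # hypothetical IDs
            ),
            sampling_params=sglang_scheduler_pb2.SamplingParams(
                temperature=0.7, max_new_tokens=32
            ),
        )
        # Generate is server-streaming: iterate chunks until complete/error.
        async for response in stub.Generate(request):
            kind = response.WhichOneof("response")
            if kind == "chunk":
                print(response.chunk.text, end="", flush=True)
            elif kind == "complete":
                print(f"\n[finish_reason: {response.complete.finish_reason}]")
            else:
                raise RuntimeError(response.error.message)


asyncio.run(main())
```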
python/sglang/srt/grpc/__init__.py (new file, mode 100644, +1)

```python
# SGLang gRPC module
```
python/sglang/srt/grpc/sglang_scheduler.proto (new file, mode 100644, +389)

```proto
syntax = "proto3";

package sglang.grpc.scheduler;

import "google/protobuf/timestamp.proto";
import "google/protobuf/struct.proto";

// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
  // Submit a generation request (supports streaming)
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check and metrics
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request
  rpc Abort(AbortRequest) returns (AbortResponse);
}

// =====================
// Common Types
// =====================

// Sampling parameters matching SGLang's SamplingParams
message SamplingParams {
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;
  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter
  string lora_path = 16;

  // Speculative decoding
  int32 n = 17;  // Number of samples

  // Token healing
  bool token_healing = 18;

  // Additional parameters
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  map<string, float> logit_bias = 23;
  string structural_tag = 24;

  // Custom parameters for extensibility
  google.protobuf.Struct custom_params = 25;
}

// Disaggregated serving parameters
message DisaggregatedParams {
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  int32 bootstrap_room = 3;
}

// =====================
// Generate Request
// =====================

message GenerateRequest {
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 3;

  // Generation parameters
  SamplingParams sampling_params = 4;

  // Return options
  bool return_logprob = 5;
  int32 logprob_start_len = 6;
  int32 top_logprobs_num = 7;
  repeated int32 token_ids_logprob = 8;
  bool return_hidden_states = 9;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 10;

  // Custom logit processor (serialized)
  string custom_logit_processor = 11;

  // Request metadata
  google.protobuf.Timestamp timestamp = 12;
  bool log_metrics = 13;

  // Input embeddings (alternative to text/tokens)
  repeated float input_embeds = 14;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 15;

  // Data parallel routing
  int32 data_parallel_rank = 16;

  // For load balancing
  int32 dp_balance_id = 17;
}

message TokenizedInput {
  string original_text = 1;  // For reference
  repeated int32 input_ids = 2;
}

message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available)
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata
  repeated string modalities = 8;
}

// =====================
// Generate Response
// =====================

message GenerateResponse {
  string request_id = 1;

  // Response type
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}

message GenerateStreamChunk {
  // Generated token
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested)
  LogProbs logprobs = 6;

  // Hidden states (if requested)
  repeated float hidden_states = 7;

  // Metadata
  float generation_time = 8;  // Time to generate this token
  int32 queue_time = 9;       // Time spent in queue
}

message GenerateComplete {
  // Final output
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested
  repeated HiddenStates all_hidden_states = 12;
}

message GenerateError {
  string message = 1;
  string http_status_code = 2;
  string details = 3;
}

message LogProbs {
  repeated float token_logprobs = 1;
  repeated int32 token_ids = 2;

  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}

message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}

message HiddenStates {
  repeated float values = 1;
  int32 layer = 2;
  int32 position = 3;
}

// =====================
// Embedding Request
// =====================

message EmbedRequest {
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility
  // EmbedRequest doesn't use sampling_params
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them
  repeated int32 token_type_ids = 7;

  // Data parallel routing
  int32 data_parallel_rank = 8;

  // For cross-encoder requests
  bool is_cross_encoder = 9;
  repeated string texts = 10;  // For cross-encoder batch
}

message EmbedResponse {
  string request_id = 1;

  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}

message EmbedComplete {
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata
  int32 embedding_dim = 4;
  float generation_time = 5;

  // For batch embeddings
  repeated Embedding batch_embeddings = 6;
}

message Embedding {
  repeated float values = 1;
  int32 index = 2;
}

message EmbedError {
  string message = 1;
  string code = 2;
  string details = 3;
}

// =====================
// Management Operations
// =====================

message HealthCheckRequest {
  // Input for health test generation (must be tokenized)
  TokenizedInput tokenized = 1;
}

message HealthCheckResponse {
  bool healthy = 1;
  string message = 2;
}

message AbortRequest {
  string request_id = 1;
  string reason = 2;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}

// =====================
// Additional Operations (Future)
// =====================

// Load LoRA adapter
message LoadLoRARequest {
  string adapter_id = 1;
  string adapter_path = 2;
  int32 rank = 3;
}

message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;
  string message = 3;
}

// Unload LoRA adapter
message UnloadLoRARequest {
  string adapter_id = 1;
}

message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;
}

// Update weights
message UpdateWeightsRequest {
  oneof source {
    string disk_path = 1;
    bytes tensor_data = 2;
    string remote_url = 3;
  }
  string weight_name = 4;
}

message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;
}

// Get internal state for debugging
message GetInternalStateRequest {
  repeated string state_keys = 1;
}

message GetInternalStateResponse {
  google.protobuf.Struct state = 1;
}

// Set internal state for testing
message SetInternalStateRequest {
  google.protobuf.Struct state = 1;
}

message SetInternalStateResponse {
  bool success = 1;
  string message = 2;
}
```
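A note on `SamplingParams`: the `regex`, `json_schema`, and `ebnf_grammar` constraints sit in a `oneof`, so at most one is ever set and assigning one clears the others, which is why `_convert_sampling_params` dispatches with `HasField`. A small sketch of that behavior against the generated bindings:

```python
# Demonstrates proto3 `oneof constraint` semantics in SamplingParams.
from sglang.srt.grpc import sglang_scheduler_pb2

params = sglang_scheduler_pb2.SamplingParams(temperature=0.0)
params.json_schema = '{"type": "object"}'
assert params.WhichOneof("constraint") == "json_schema"

params.regex = r"\d+"  # assigning another oneof member clears json_schema
assert params.WhichOneof("constraint") == "regex"
assert not params.HasField("json_schema")
```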
python/sglang/srt/grpc/sglang_scheduler_pb2.py
0 → 100644
View file @
53ca1552
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: sglang_scheduler.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
descriptor_pool
as
_descriptor_pool
from
google.protobuf
import
runtime_version
as
_runtime_version
from
google.protobuf
import
symbol_database
as
_symbol_database
from
google.protobuf.internal
import
builder
as
_builder
_runtime_version
.
ValidateProtobufRuntimeVersion
(
_runtime_version
.
Domain
.
PUBLIC
,
6
,
31
,
1
,
''
,
'sglang_scheduler.proto'
)
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
from
google.protobuf
import
timestamp_pb2
as
google_dot_protobuf_dot_timestamp__pb2
from
google.protobuf
import
struct_pb2
as
google_dot_protobuf_dot_struct__pb2
DESCRIPTOR
=
_descriptor_pool
.
Default
().
AddSerializedFile
(
b
'
\n\x16
sglang_scheduler.proto
\x12\x15
sglang.grpc.scheduler
\x1a\x1f
google/protobuf/timestamp.proto
\x1a\x1c
google/protobuf/struct.proto
\"\xc7\x05\n\x0e
SamplingParams
\x12\x13\n\x0b
temperature
\x18\x01
\x01
(
\x02\x12\r\n\x05
top_p
\x18\x02
\x01
(
\x02\x12\r\n\x05
top_k
\x18\x03
\x01
(
\x05\x12\r\n\x05
min_p
\x18\x04
\x01
(
\x02\x12\x19\n\x11\x66
requency_penalty
\x18\x05
\x01
(
\x02\x12\x18\n\x10
presence_penalty
\x18\x06
\x01
(
\x02\x12\x1a\n\x12
repetition_penalty
\x18\x07
\x01
(
\x02\x12\x16\n\x0e
max_new_tokens
\x18\x08
\x01
(
\x05\x12\x0c\n\x04
stop
\x18\t
\x03
(
\t\x12\x16\n\x0e
stop_token_ids
\x18\n
\x03
(
\x05\x12\x1b\n\x13
skip_special_tokens
\x18\x0b
\x01
(
\x08\x12
%
\n\x1d
spaces_between_special_tokens
\x18\x0c
\x01
(
\x08\x12\x0f\n\x05
regex
\x18\r
\x01
(
\t
H
\x00\x12\x15\n\x0b
json_schema
\x18\x0e
\x01
(
\t
H
\x00\x12\x16\n\x0c\x65\x62
nf_grammar
\x18\x0f
\x01
(
\t
H
\x00\x12\x11\n\t
lora_path
\x18\x10
\x01
(
\t\x12\t\n\x01
n
\x18\x11
\x01
(
\x05\x12\x15\n\r
token_healing
\x18\x12
\x01
(
\x08\x12\x16\n\x0e
min_new_tokens
\x18\x13
\x01
(
\x05\x12\x12\n\n
ignore_eos
\x18\x14
\x01
(
\x08\x12\x14\n\x0c
no_stop_trim
\x18\x15
\x01
(
\x08\x12\x17\n\x0f
stream_interval
\x18\x16
\x01
(
\x05\x12
H
\n\n
logit_bias
\x18\x17
\x03
(
\x0b\x32\x34
.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry
\x12\x16\n\x0e
structural_tag
\x18\x18
\x01
(
\t\x12
.
\n\r
custom_params
\x18\x19
\x01
(
\x0b\x32\x17
.google.protobuf.Struct
\x1a\x30\n\x0e
LogitBiasEntry
\x12\x0b\n\x03
key
\x18\x01
\x01
(
\t\x12\r\n\x05
value
\x18\x02
\x01
(
\x02
:
\x02\x38\x01\x42\x0c\n\n
constraint
\"
]
\n\x13\x44
isaggregatedParams
\x12\x16\n\x0e\x62
ootstrap_host
\x18\x01
\x01
(
\t\x12\x16\n\x0e\x62
ootstrap_port
\x18\x02
\x01
(
\x05\x12\x16\n\x0e\x62
ootstrap_room
\x18\x03
\x01
(
\x05\"\xe9\x04\n\x0f
GenerateRequest
\x12\x12\n\n
request_id
\x18\x01
\x01
(
\t\x12\x38\n\t
tokenized
\x18\x02
\x01
(
\x0b\x32
%.sglang.grpc.scheduler.TokenizedInput
\x12
:
\n\t
mm_inputs
\x18\x03
\x01
(
\x0b\x32\'
.sglang.grpc.scheduler.MultimodalInputs
\x12
>
\n\x0f
sampling_params
\x18\x04
\x01
(
\x0b\x32
%.sglang.grpc.scheduler.SamplingParams
\x12\x16\n\x0e
return_logprob
\x18\x05
\x01
(
\x08\x12\x19\n\x11
logprob_start_len
\x18\x06
\x01
(
\x05\x12\x18\n\x10
top_logprobs_num
\x18\x07
\x01
(
\x05\x12\x19\n\x11
token_ids_logprob
\x18\x08
\x03
(
\x05\x12\x1c\n\x14
return_hidden_states
\x18\t
\x01
(
\x08\x12
H
\n\x14\x64
isaggregated_params
\x18\n
\x01
(
\x0b\x32
*.sglang.grpc.scheduler.DisaggregatedParams
\x12\x1e\n\x16\x63
ustom_logit_processor
\x18\x0b
\x01
(
\t\x12
-
\n\t
timestamp
\x18\x0c
\x01
(
\x0b\x32\x1a
.google.protobuf.Timestamp
\x12\x13\n\x0b
log_metrics
\x18\r
\x01
(
\x08\x12\x14\n\x0c
input_embeds
\x18\x0e
\x03
(
\x02\x12\x0f\n\x07
lora_id
\x18\x0f
\x01
(
\t\x12\x1a\n\x12\x64\x61
ta_parallel_rank
\x18\x10
\x01
(
\x05\x12\x15\n\r
dp_balance_id
\x18\x11
\x01
(
\x05\"
:
\n\x0e
TokenizedInput
\x12\x15\n\r
original_text
\x18\x01
\x01
(
\t\x12\x11\n\t
input_ids
\x18\x02
\x03
(
\x05\"\xd3\x01\n\x10
MultimodalInputs
\x12\x12\n\n
image_urls
\x18\x01
\x03
(
\t\x12\x12\n\n
video_urls
\x18\x02
\x03
(
\t\x12\x12\n\n
audio_urls
\x18\x03
\x03
(
\t\x12\x33\n\x12
processed_features
\x18\x04
\x01
(
\x0b\x32\x17
.google.protobuf.Struct
\x12\x12\n\n
image_data
\x18\x05
\x03
(
\x0c\x12\x12\n\n
video_data
\x18\x06
\x03
(
\x0c\x12\x12\n\n
audio_data
\x18\x07
\x03
(
\x0c\x12\x12\n\n
modalities
\x18\x08
\x03
(
\t\"\xe3\x01\n\x10
GenerateResponse
\x12\x12\n\n
request_id
\x18\x01
\x01
(
\t\x12
;
\n\x05\x63
hunk
\x18\x02
\x01
(
\x0b\x32
*.sglang.grpc.scheduler.GenerateStreamChunkH
\x00\x12
;
\n\x08\x63
omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xf5\x01\n\x13GenerateStreamChunk\x12\x10\n\x08token_id\x18\x01 \x01(\x05\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x31\n\x08logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x07 \x03(\x02\x12\x17\n\x0fgeneration_time\x18\x08 \x01(\x02\x12\x12\n\nqueue_time\x18\t \x01(\x05\"\xcd\x02\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12K\n\rfinish_reason\x18\x03 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x35\n\x0c\x61ll_logprobs\x18\x0b \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x0c \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xbc\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12\x17\n\x0fgeneration_time\x18\x05 \x01(\x02\x12:\n\x10\x62\x61tch_embeddings\x18\x06 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sglang_scheduler_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001'
  _globals['_SAMPLINGPARAMS']._serialized_start=113
  _globals['_SAMPLINGPARAMS']._serialized_end=824
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=762
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=810
  _globals['_DISAGGREGATEDPARAMS']._serialized_start=826
  _globals['_DISAGGREGATEDPARAMS']._serialized_end=919
  _globals['_GENERATEREQUEST']._serialized_start=922
  _globals['_GENERATEREQUEST']._serialized_end=1539
  _globals['_TOKENIZEDINPUT']._serialized_start=1541
  _globals['_TOKENIZEDINPUT']._serialized_end=1599
  _globals['_MULTIMODALINPUTS']._serialized_start=1602
  _globals['_MULTIMODALINPUTS']._serialized_end=1813
  _globals['_GENERATERESPONSE']._serialized_start=1816
  _globals['_GENERATERESPONSE']._serialized_end=2043
  _globals['_GENERATESTREAMCHUNK']._serialized_start=2046
  _globals['_GENERATESTREAMCHUNK']._serialized_end=2291
  _globals['_GENERATECOMPLETE']._serialized_start=2294
  _globals['_GENERATECOMPLETE']._serialized_end=2627
  _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2551
  _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2627
  _globals['_GENERATEERROR']._serialized_start=2629
  _globals['_GENERATEERROR']._serialized_end=2704
  _globals['_LOGPROBS']._serialized_start=2707
  _globals['_LOGPROBS']._serialized_end=2839
  _globals['_TOPLOGPROBS']._serialized_start=2841
  _globals['_TOPLOGPROBS']._serialized_end=2910
  _globals['_HIDDENSTATES']._serialized_start=2912
  _globals['_HIDDENSTATES']._serialized_end=2975
  _globals['_EMBEDREQUEST']._serialized_start=2978
  _globals['_EMBEDREQUEST']._serialized_end=3308
  _globals['_EMBEDRESPONSE']._serialized_start=3311
  _globals['_EMBEDRESPONSE']._serialized_end=3468
  _globals['_EMBEDCOMPLETE']._serialized_start=3471
  _globals['_EMBEDCOMPLETE']._serialized_end=3659
  _globals['_EMBEDDING']._serialized_start=3661
  _globals['_EMBEDDING']._serialized_end=3703
  _globals['_EMBEDERROR']._serialized_start=3705
  _globals['_EMBEDERROR']._serialized_end=3765
  _globals['_HEALTHCHECKREQUEST']._serialized_start=3767
  _globals['_HEALTHCHECKREQUEST']._serialized_end=3845
  _globals['_HEALTHCHECKRESPONSE']._serialized_start=3847
  _globals['_HEALTHCHECKRESPONSE']._serialized_end=3902
  _globals['_ABORTREQUEST']._serialized_start=3904
  _globals['_ABORTREQUEST']._serialized_end=3954
  _globals['_ABORTRESPONSE']._serialized_start=3956
  _globals['_ABORTRESPONSE']._serialized_end=4005
  _globals['_LOADLORAREQUEST']._serialized_start=4007
  _globals['_LOADLORAREQUEST']._serialized_end=4080
  _globals['_LOADLORARESPONSE']._serialized_start=4082
  _globals['_LOADLORARESPONSE']._serialized_end=4154
  _globals['_UNLOADLORAREQUEST']._serialized_start=4156
  _globals['_UNLOADLORAREQUEST']._serialized_end=4195
  _globals['_UNLOADLORARESPONSE']._serialized_start=4197
  _globals['_UNLOADLORARESPONSE']._serialized_end=4251
  _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4253
  _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4372
  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4374
  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4431
  _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4433
  _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4478
  _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4480
  _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4546
  _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4548
  _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4613
  _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4615
  _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4675
  _globals['_SGLANGSCHEDULER']._serialized_start=4678
  _globals['_SGLANGSCHEDULER']._serialized_end=5060
# @@protoc_insertion_point(module_scope)
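For orientation, a minimal sketch of how a caller builds a request with these generated classes; the field names come from the schema above, while the token IDs and sampling values are illustrative, not taken from the commit:

# Sketch: building a GenerateRequest from the generated module (illustrative values).
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

request = pb2.GenerateRequest(
    request_id="req-1",
    tokenized=pb2.TokenizedInput(
        original_text="Hello world",
        input_ids=[9906, 1917],  # input arrives pre-tokenized; there is no raw-text field
    ),
    sampling_params=pb2.SamplingParams(temperature=0.7, top_p=0.95, max_new_tokens=64),
)
payload = request.SerializeToString()  # wire bytes carried over the gRPC channel
assert pb2.GenerateRequest.FromString(payload).request_id == "req-1"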
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
0 → 100644
import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: float
def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ...
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
TOP_P_FIELD_NUMBER: _ClassVar[int]
TOP_K_FIELD_NUMBER: _ClassVar[int]
MIN_P_FIELD_NUMBER: _ClassVar[int]
FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int]
PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int]
MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
STOP_FIELD_NUMBER: _ClassVar[int]
STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
top_k: int
min_p: float
frequency_penalty: float
presence_penalty: float
repetition_penalty: float
max_new_tokens: int
stop: _containers.RepeatedScalarFieldContainer[str]
stop_token_ids: _containers.RepeatedScalarFieldContainer[int]
skip_special_tokens: bool
spaces_between_special_tokens: bool
regex: str
json_schema: str
ebnf_grammar: str
lora_path: str
n: int
token_healing: bool
min_new_tokens: int
ignore_eos: bool
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int]
BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int]
BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int]
bootstrap_host: str
bootstrap_port: int
bootstrap_room: int
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int]
LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int]
RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int]
CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int]
TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
sampling_params: SamplingParams
return_logprob: bool
logprob_start_len: int
top_logprobs_num: int
token_ids_logprob: _containers.RepeatedScalarFieldContainer[int]
return_hidden_states: bool
disaggregated_params: DisaggregatedParams
custom_logit_processor: str
timestamp: _timestamp_pb2.Timestamp
log_metrics: bool
input_embeds: _containers.RepeatedScalarFieldContainer[float]
lora_id: str
data_parallel_rank: int
dp_balance_id: int
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int]
INPUT_IDS_FIELD_NUMBER: _ClassVar[int]
original_text: str
input_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ...
class MultimodalInputs(_message.Message):
__slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities")
IMAGE_URLS_FIELD_NUMBER: _ClassVar[int]
VIDEO_URLS_FIELD_NUMBER: _ClassVar[int]
AUDIO_URLS_FIELD_NUMBER: _ClassVar[int]
PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int]
IMAGE_DATA_FIELD_NUMBER: _ClassVar[int]
VIDEO_DATA_FIELD_NUMBER: _ClassVar[int]
AUDIO_DATA_FIELD_NUMBER: _ClassVar[int]
MODALITIES_FIELD_NUMBER: _ClassVar[int]
image_urls: _containers.RepeatedScalarFieldContainer[str]
video_urls: _containers.RepeatedScalarFieldContainer[str]
audio_urls: _containers.RepeatedScalarFieldContainer[str]
processed_features: _struct_pb2.Struct
image_data: _containers.RepeatedScalarFieldContainer[bytes]
video_data: _containers.RepeatedScalarFieldContainer[bytes]
audio_data: _containers.RepeatedScalarFieldContainer[bytes]
modalities: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ...
class GenerateResponse(_message.Message):
__slots__ = ("request_id", "chunk", "complete", "error")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
CHUNK_FIELD_NUMBER: _ClassVar[int]
COMPLETE_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
request_id: str
chunk: GenerateStreamChunk
complete: GenerateComplete
error: GenerateError
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
generation_time: float
queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
output_text: str
finish_reason: GenerateComplete.FinishReason
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int]
DETAILS_FIELD_NUMBER: _ClassVar[int]
message: str
http_status_code: str
details: str
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
VALUES_FIELD_NUMBER: _ClassVar[int]
LAYER_FIELD_NUMBER: _ClassVar[int]
POSITION_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
layer: int
position: int
def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ...
class EmbedRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int]
TEXTS_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
sampling_params: SamplingParams
log_metrics: bool
token_type_ids: _containers.RepeatedScalarFieldContainer[int]
data_parallel_rank: int
is_cross_encoder: bool
texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ...
class EmbedResponse(_message.Message):
__slots__ = ("request_id", "complete", "error")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
COMPLETE_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
request_id: str
complete: EmbedComplete
error: EmbedError
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")
VALUES_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
index: int
def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ...
class EmbedError(_message.Message):
__slots__ = ("message", "code", "details")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
CODE_FIELD_NUMBER: _ClassVar[int]
DETAILS_FIELD_NUMBER: _ClassVar[int]
message: str
code: str
details: str
def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class HealthCheckRequest(_message.Message):
__slots__ = ("tokenized",)
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
tokenized: TokenizedInput
def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...
class HealthCheckResponse(_message.Message):
__slots__ = ("healthy", "message")
HEALTHY_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
healthy: bool
message: str
def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ...
class AbortRequest(_message.Message):
__slots__ = ("request_id", "reason")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
request_id: str
reason: str
def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ...
class AbortResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class LoadLoRARequest(_message.Message):
__slots__ = ("adapter_id", "adapter_path", "rank")
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int]
RANK_FIELD_NUMBER: _ClassVar[int]
adapter_id: str
adapter_path: str
rank: int
def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ...
class LoadLoRAResponse(_message.Message):
__slots__ = ("success", "adapter_id", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
adapter_id: str
message: str
def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ...
class UnloadLoRARequest(_message.Message):
__slots__ = ("adapter_id",)
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
adapter_id: str
def __init__(self, adapter_id: _Optional[str] = ...) -> None: ...
class UnloadLoRAResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class UpdateWeightsRequest(_message.Message):
__slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name")
DISK_PATH_FIELD_NUMBER: _ClassVar[int]
TENSOR_DATA_FIELD_NUMBER: _ClassVar[int]
REMOTE_URL_FIELD_NUMBER: _ClassVar[int]
WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int]
disk_path: str
tensor_data: bytes
remote_url: str
weight_name: str
def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ...
class UpdateWeightsResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class GetInternalStateRequest(_message.Message):
__slots__ = ("state_keys",)
STATE_KEYS_FIELD_NUMBER: _ClassVar[int]
state_keys: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ...
class GetInternalStateResponse(_message.Message):
__slots__ = ("state",)
STATE_FIELD_NUMBER: _ClassVar[int]
state: _struct_pb2.Struct
def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class SetInternalStateRequest(_message.Message):
__slots__ = ("state",)
STATE_FIELD_NUMBER: _ClassVar[int]
state: _struct_pb2.Struct
def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class SetInternalStateResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
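As the stubs above show, GenerateResponse carries exactly one of chunk, complete, or error inside its response oneof, so a consumer dispatches on WhichOneof. A small sketch (the print/raise handling is illustrative):

# Sketch: dispatching on the GenerateResponse oneof; names are from the stubs above.
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

def handle(resp: pb2.GenerateResponse) -> bool:
    """Returns True once the stream is finished."""
    which = resp.WhichOneof("response")
    if which == "chunk":
        print(resp.chunk.text, end="", flush=True)  # incremental token text
        return False
    if which == "complete":
        print("\nfinished:", pb2.GenerateComplete.FinishReason.Name(resp.complete.finish_reason))
        return True
    raise RuntimeError(f"{resp.error.message} (HTTP {resp.error.http_status_code})")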
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
0 → 100644
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from . import sglang_scheduler_pb2 as sglang__scheduler__pb2

GRPC_GENERATED_VERSION = '1.74.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class SglangSchedulerStub(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Generate = channel.unary_stream(
                '/sglang.grpc.scheduler.SglangScheduler/Generate',
                request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString,
                _registered_method=True)
        self.Embed = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/Embed',
                request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString,
                _registered_method=True)
        self.HealthCheck = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
                request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString,
                _registered_method=True)
        self.Abort = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/Abort',
                request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
                _registered_method=True)


class SglangSchedulerServicer(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    def Generate(self, request, context):
        """Submit a generation request (supports streaming)
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Embed(self, request, context):
        """Submit an embedding request
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def HealthCheck(self, request, context):
        """Health check and metrics
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Abort(self, request, context):
        """Abort a running request
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_SglangSchedulerServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Generate': grpc.unary_stream_rpc_method_handler(
                    servicer.Generate,
                    request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString,
            ),
            'Embed': grpc.unary_unary_rpc_method_handler(
                    servicer.Embed,
                    request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString,
            ),
            'HealthCheck': grpc.unary_unary_rpc_method_handler(
                    servicer.HealthCheck,
                    request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString,
            ),
            'Abort': grpc.unary_unary_rpc_method_handler(
                    servicer.Abort,
                    request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class SglangScheduler(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    @staticmethod
    def Generate(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Generate',
            sglang__scheduler__pb2.GenerateRequest.SerializeToString,
            sglang__scheduler__pb2.GenerateResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Embed(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Embed',
            sglang__scheduler__pb2.EmbedRequest.SerializeToString,
            sglang__scheduler__pb2.EmbedResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def HealthCheck(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
            sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
            sglang__scheduler__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Abort(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Abort',
            sglang__scheduler__pb2.AbortRequest.SerializeToString,
            sglang__scheduler__pb2.AbortResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
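A short sketch of driving the service with these stubs from a blocking Python client; the endpoint address and token IDs are illustrative (in this commit the production caller is the Rust router):

# Sketch: calling the scheduler service via the generated stub (illustrative endpoint).
import grpc
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

with grpc.insecure_channel("localhost:30000") as channel:
    stub = pb2_grpc.SglangSchedulerStub(channel)
    # Generate is unary-stream: one request in, a stream of GenerateResponse out.
    for resp in stub.Generate(pb2.GenerateRequest(
        request_id="req-1",
        tokenized=pb2.TokenizedInput(original_text="Hello", input_ids=[9906]),
        sampling_params=pb2.SamplingParams(max_new_tokens=16),
    )):
        if resp.WhichOneof("response") in ("complete", "error"):
            break
    # The remaining RPCs are plain unary calls.
    health = stub.HealthCheck(pb2.HealthCheckRequest(
        tokenized=pb2.TokenizedInput(original_text="ping", input_ids=[1296])))
    print(health.healthy, health.message)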
python/sglang/srt/server_args.py
...
@@ -2238,6 +2238,7 @@ class ServerArgs:
        args.pp_size = args.pipeline_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
...
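The tail of this hunk rebuilds the dataclass by copying every declared field off the parsed argparse namespace. The same pattern in isolation (all names here are illustrative, not from the commit):

# Sketch of the dataclass-from-argparse pattern used above (illustrative names).
import argparse
import dataclasses

@dataclasses.dataclass
class DemoArgs:
    host: str = "127.0.0.1"
    port: int = 30000

parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
ns = parser.parse_args(["--port", "8080"])

attrs = [f.name for f in dataclasses.fields(DemoArgs)]
demo = DemoArgs(**{a: getattr(ns, a) for a in attrs})
assert demo.port == 8080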
sgl-router/src/grpc/client.rs
...
@@ -37,21 +37,6 @@ impl SglangSchedulerClient {
        Ok(Self { client })
    }

    /// Initialize the connection
    pub async fn initialize(
        &mut self,
        client_id: String,
    ) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
        let request = Request::new(proto::InitializeRequest {
            client_id,
            client_version: "0.1.0".to_string(),
            mode: proto::initialize_request::Mode::Regular as i32,
        });
        let response = self.client.initialize(request).await?;
        Ok(response.into_inner())
    }

    /// Submit a generation request (returns streaming response)
    pub async fn generate_stream(
        &mut self,
...
@@ -68,7 +53,10 @@
    ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
        debug!("Sending health check request");
        let request = Request::new(proto::HealthCheckRequest {
            include_detailed_metrics: false,
            tokenized: Some(proto::TokenizedInput {
                original_text: "Hello".to_string(),
                input_ids: vec![9906], // Mock token ID for "Hello"
            }),
        });
        let response = self.client.health_check(request).await?;
...
@@ -87,21 +75,6 @@
        self.client.abort(request).await?;
        Ok(())
    }

    /// Flush cache
    pub async fn flush_cache(
        &mut self,
        flush_all: bool,
        session_ids: &[String],
    ) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
        let request = Request::new(proto::FlushCacheRequest {
            flush_all,
            session_ids: session_ids.to_vec(),
        });
        let response = self.client.flush_cache(request).await?;
        Ok(response.into_inner())
    }
}

#[cfg(test)]
...
@@ -111,14 +84,13 @@ mod tests {
    #[test]
    fn test_proto_types_compilation() {
        // Test that protobuf types can be constructed
        let init_req = proto::InitializeRequest {
            client_id: "test-client".to_string(),
            client_version: "0.1.0".to_string(),
            mode: 0,
        let health_req = proto::HealthCheckRequest {
            tokenized: Some(proto::TokenizedInput {
                original_text: "test".to_string(),
                input_ids: vec![1296],
            }),
        };
        assert_eq!(init_req.client_id, "test-client");
        assert_eq!(init_req.client_version, "0.1.0");
        assert_eq!(init_req.mode, 0);
        assert!(health_req.tokenized.is_some());
    }
...
@@ -134,9 +106,10 @@
        let gen_req = proto::GenerateRequest {
            request_id: "test-req-123".to_string(),
            input: Some(proto::generate_request::Input::Text(
                "Hello world".to_string(),
            )),
            tokenized: Some(proto::TokenizedInput {
                original_text: "Hello world".to_string(),
                input_ids: vec![9906, 1917], // Mock token IDs for "Hello world"
            }),
            sampling_params: Some(sampling_params),
            return_logprob: true,
            logprob_start_len: 0,
...
@@ -145,8 +118,8 @@
        };

        assert_eq!(gen_req.request_id, "test-req-123");
        if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
            assert_eq!(text, "Hello world");
        if let Some(ref tokenized) = &gen_req.tokenized {
            assert_eq!(tokenized.original_text, "Hello world");
        }
        assert!(gen_req.return_logprob);
        assert_eq!(gen_req.top_logprobs_num, 5);
...
@@ -160,9 +133,12 @@
    #[test]
    fn test_health_check_request() {
        let health_req = proto::HealthCheckRequest {
            include_detailed_metrics: true,
            tokenized: Some(proto::TokenizedInput {
                original_text: "test".to_string(),
                input_ids: vec![1296], // Mock token ID for "test"
            }),
        };

        assert!(health_req.include_detailed_metrics);
        assert!(health_req.tokenized.is_some());
    }
...
@@ -175,17 +151,6 @@
        assert_eq!(abort_req.reason, "User canceled");
    }

    #[test]
    fn test_flush_cache_request() {
        let flush_req = proto::FlushCacheRequest {
            flush_all: true,
            session_ids: vec!["session1".to_string(), "session2".to_string()],
        };

        assert!(flush_req.flush_all);
        assert_eq!(flush_req.session_ids.len(), 2);
        assert_eq!(flush_req.session_ids[0], "session1");
    }

    #[test]
    fn test_sampling_params_defaults() {
        let params = proto::SamplingParams::default();
...
@@ -214,38 +179,29 @@
        assert_eq!(mm_inputs.modalities[0], "image");
    }

    #[test]
    fn test_session_params() {
        let session_params = proto::SessionParams {
            session_id: "sess-789".to_string(),
            request_id: "req-101".to_string(),
            offset: 100,
            replace: true,
            drop_previous_output: false,
        };

        assert_eq!(session_params.session_id, "sess-789");
        assert_eq!(session_params.request_id, "req-101");
        assert_eq!(session_params.offset, 100);
        assert!(session_params.replace);
        assert!(!session_params.drop_previous_output);
    }
    // TODO: SessionParams not in current proto - skip test
    // #[test]
    // fn test_session_params() { ... }

    #[test]
    fn test_embed_request() {
        let embed_req = proto::EmbedRequest {
            request_id: "embed-req-202".to_string(),
            input: Some(proto::embed_request::Input::Text(
                "This is a test sentence for embedding".to_string(),
            )),
            tokenized: Some(proto::TokenizedInput {
                original_text: "This is a test sentence for embedding".to_string(),
                input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs
            }),
            log_metrics: true,
            data_parallel_rank: 0,
            ..Default::default()
        };

        assert_eq!(embed_req.request_id, "embed-req-202");
        if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
            assert_eq!(text, "This is a test sentence for embedding");
        if let Some(ref tokenized) = &embed_req.tokenized {
            assert_eq!(tokenized.original_text, "This is a test sentence for embedding");
        }
        assert!(embed_req.log_metrics);
        assert_eq!(embed_req.data_parallel_rank, 0);
...
@@ -292,36 +248,7 @@
        assert_eq!(chunk.queue_time, 10);
    }

    #[test]
    fn test_model_info() {
        let model_info = proto::ModelInfo {
            model_name: "Meta-Llama-3-8B-Instruct".to_string(),
            max_context_length: 8192,
            vocab_size: 128256,
            supports_tool_calling: true,
            supports_vision: false,
            special_tokens: vec![
                "<|begin_of_text|>".to_string(),
                "<|end_of_text|>".to_string(),
            ],
            model_type: "llama".to_string(),
            num_layers: 32,
            hidden_size: 4096,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            tokenizer_type: "llama".to_string(),
            eos_token_ids: vec![128001, 128009],
            pad_token_id: 128001,
            bos_token_id: 128000,
        };

        assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
        assert_eq!(model_info.max_context_length, 8192);
        assert_eq!(model_info.vocab_size, 128256);
        assert!(model_info.supports_tool_calling);
        assert!(!model_info.supports_vision);
        assert_eq!(model_info.special_tokens.len(), 2);
        assert_eq!(model_info.num_layers, 32);
        assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
    }
    // TODO: ModelInfo not in current proto - skip test
    // #[test]
    // fn test_model_info() { ... }
}
sgl-router/src/proto/sglang_scheduler.proto
...
@@ -8,9 +8,6 @@ import "google/protobuf/struct.proto";
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
  // Initialize connection and get model info
  rpc Initialize(InitializeRequest) returns (InitializeResponse);

  // Submit a generation request (supports streaming)
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);
...
@@ -23,8 +20,6 @@ service SglangScheduler {
  // Abort a running request
  rpc Abort(AbortRequest) returns (AbortResponse);

  // Flush KV cache
  rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
}

// =====================
...
@@ -75,14 +70,6 @@ message SamplingParams {
  google.protobuf.Struct custom_params = 25;
}

// Session parameters for continual prompting
message SessionParams {
  string session_id = 1;
  string request_id = 2;
  int32 offset = 3;
  bool replace = 4;
  bool drop_previous_output = 5;
}

// Disaggregated serving parameters
message DisaggregatedParams {
...
@@ -91,87 +78,6 @@ message DisaggregatedParams {
  int32 bootstrap_room = 3;
}

// =====================
// Initialize
// =====================

message InitializeRequest {
  string client_id = 1;
  string client_version = 2;

  // Operating mode
  enum Mode {
    REGULAR = 0;  // Normal mode with local scheduler
    PREFILL = 1;  // Prefill-only mode for disaggregated serving
    DECODE = 2;   // Decode-only mode for disaggregated serving
  }
  Mode mode = 3;
}

message InitializeResponse {
  bool success = 1;
  string scheduler_version = 2;
  // Model information
  ModelInfo model_info = 3;
  // Server capabilities
  ServerCapabilities capabilities = 4;
  // Error message if success is false
  string error_message = 5;
}

message ModelInfo {
  string model_name = 1;
  int32 max_context_length = 2;
  int32 vocab_size = 3;
  bool supports_tool_calling = 4;
  bool supports_vision = 5;
  repeated string special_tokens = 6;
  // Additional model metadata
  string model_type = 7;
  int32 num_layers = 8;
  int32 hidden_size = 9;
  int32 num_attention_heads = 10;
  int32 num_key_value_heads = 11;
  // Tokenizer info
  string tokenizer_type = 12;
  repeated int32 eos_token_ids = 13;
  int32 pad_token_id = 14;
  int32 bos_token_id = 15;
}

message ServerCapabilities {
  bool continuous_batching = 1;
  bool disaggregated_serving = 2;
  bool speculative_decoding = 3;
  int32 max_batch_size = 4;
  int32 max_num_batched_tokens = 5;
  int32 max_prefill_tokens = 6;
  string attention_backend = 7;  // "flashinfer", "triton", "torch"
  // Additional capabilities
  bool supports_lora = 8;
  bool supports_grammar = 9;
  bool supports_multimodal = 10;
  repeated string supported_modalities = 11;  // ["image", "video", "audio"]
  bool supports_custom_logit_processor = 12;
  bool supports_session = 13;
  // Hardware info
  int32 num_gpus = 14;
  string gpu_type = 15;
  int64 total_gpu_memory = 16;
  // Parallelism info
  int32 tensor_parallel_size = 17;
  int32 pipeline_parallel_size = 18;
  int32 data_parallel_size = 19;
}

// =====================
// Generate Request
// =====================
...
@@ -179,49 +85,43 @@ message ServerCapabilities {
message GenerateRequest {
  string request_id = 1;

  // Input can be either text or tokenized
  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }
  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;
  MultimodalInputs mm_inputs = 3;

  // Generation parameters
  SamplingParams sampling_params = 5;
  SamplingParams sampling_params = 4;

  // Return options
  bool return_logprob = 6;
  int32 logprob_start_len = 7;
  int32 top_logprobs_num = 8;
  repeated int32 token_ids_logprob = 9;
  bool return_hidden_states = 10;

  // Session management
  SessionParams session_params = 11;
  bool return_logprob = 5;
  int32 logprob_start_len = 6;
  int32 top_logprobs_num = 7;
  repeated int32 token_ids_logprob = 8;
  bool return_hidden_states = 9;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 12;
  DisaggregatedParams disaggregated_params = 10;

  // Custom logit processor (serialized)
  string custom_logit_processor = 13;
  string custom_logit_processor = 11;

  // Request metadata
  google.protobuf.Timestamp timestamp = 14;
  bool log_metrics = 15;
  google.protobuf.Timestamp timestamp = 12;
  bool log_metrics = 13;

  // Input embeddings (alternative to text/tokens)
  repeated float input_embeds = 16;
  repeated float input_embeds = 14;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 17;
  string lora_id = 15;

  // Data parallel routing
  int32 data_parallel_rank = 18;
  int32 data_parallel_rank = 16;

  // For load balancing
  int32 dp_balance_id = 19;
  int32 dp_balance_id = 17;
}

message TokenizedInput {
...
@@ -303,19 +203,6 @@ message GenerateComplete {
  }
  FinishReason finish_reason = 3;

  // Final counts
  int32 prompt_tokens = 4;
  int32 completion_tokens = 5;
  int32 cached_tokens = 6;

  // Performance metrics
  float total_generation_time = 7;
  float time_to_first_token = 8;
  float tokens_per_second = 9;

  // Spec decode metrics
  int32 spec_verify_count = 10;

  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;
...
@@ -359,10 +246,8 @@ message HiddenStates {
message EmbedRequest {
  string request_id = 1;

  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }
  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;
...
@@ -422,39 +307,13 @@ message EmbedError {
// =====================
message HealthCheckRequest {
  bool include_detailed_metrics = 1;
  // Input for health test generation (must be tokenized)
  TokenizedInput tokenized = 1;
}

message HealthCheckResponse {
  bool healthy = 1;

  // Current load metrics
  int32 num_requests_running = 2;
  int32 num_requests_waiting = 3;
  float gpu_cache_usage = 4;
  float gpu_memory_usage = 5;

  // KV cache metrics
  int32 kv_cache_total_blocks = 6;
  int32 kv_cache_used_blocks = 7;
  float kv_cache_hit_rate = 8;

  // Additional metrics
  int32 num_grammar_queue_requests = 9;
  float generation_throughput = 10;  // tokens/sec
  float average_queue_time = 11;  // seconds
  float average_generation_time = 12;  // seconds

  // System metrics
  float cpu_usage = 13;
  int64 memory_usage = 14;

  // Disaggregation metrics
  int32 num_prefill_requests = 15;
  int32 num_decode_requests = 16;

  // Detailed metrics (optional)
  google.protobuf.Struct detailed_metrics = 17;
  string message = 2;
}

message AbortRequest {
...
@@ -467,17 +326,6 @@ message AbortResponse {
  string message = 2;
}

message FlushCacheRequest {
  bool flush_all = 1;
  repeated string session_ids = 2;  // Flush specific sessions
}

message FlushCacheResponse {
  bool success = 1;
  int32 num_entries_flushed = 2;
  int64 memory_freed = 3;  // bytes
  string message = 4;
}

// =====================
// Additional Operations (Future)
...
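One schema detail worth calling out (visible in the generated Python stubs earlier in this commit): UpdateWeightsRequest keeps its weight source in a `source` oneof (disk_path, tensor_data, or remote_url), so assigning one member clears the others. A minimal sketch; the path and URL are illustrative:

# Sketch: the `source` oneof on UpdateWeightsRequest (illustrative path/URL).
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

req = pb2.UpdateWeightsRequest(disk_path="/tmp/example-weights.safetensors")
assert req.WhichOneof("source") == "disk_path"

req.remote_url = "https://example.com/weights.bin"  # assigning another member switches the oneof
assert req.WhichOneof("source") == "remote_url"
assert req.disk_path == ""  # the previous member is reset to its default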