change / sglang · Commits · 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3
Changes: 159

Showing 20 changed files with 512 additions and 1028 deletions (+512, -1028)
Changed files (additions / deletions):

  python/sglang/srt/entrypoints/grpc_request_manager.py        +41   -264
  python/sglang/srt/entrypoints/grpc_server.py                  +46   -125
  python/sglang/srt/entrypoints/openai/protocol.py              +5    -30
  python/sglang/srt/entrypoints/openai/serving_base.py          +6    -25
  python/sglang/srt/entrypoints/openai/serving_chat.py          +49   -177
  python/sglang/srt/entrypoints/openai/serving_completions.py   +3    -4
  python/sglang/srt/entrypoints/openai/serving_responses.py     +0    -2
  python/sglang/srt/eplb/expert_location.py                     +5    -30
  python/sglang/srt/function_call/function_call_parser.py       +2    -3
  python/sglang/srt/function_call/glm4_moe_detector.py          +3    -3
  python/sglang/srt/function_call/json_array_parser.py          +0    -63
  python/sglang/srt/function_call/utils.py                      +5    -96
  python/sglang/srt/grpc/sglang_scheduler.proto                 +51   -48
  python/sglang/srt/grpc/sglang_scheduler_pb2.py                +68   -69
  python/sglang/srt/grpc/sglang_scheduler_pb2.pyi               +50   -39
  python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py           +0    -3
  python/sglang/srt/hf_transformers_utils.py                    +42   -3
  python/sglang/srt/layers/attention/aiter_backend.py           +1    -5
  python/sglang/srt/layers/attention/ascend_backend.py          +99   -4
  python/sglang/srt/layers/attention/attention_registry.py      +36   -35
python/sglang/srt/entrypoints/grpc_request_manager.py

@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
"""

import asyncio
import copy
import dataclasses
import logging
import os

@@ -12,8 +11,7 @@ import signal
import sys
import threading
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import grpc
import zmq

@@ -81,10 +79,11 @@ class GrpcReqState:
    last_completion_tokens: int = 1

    # Streaming state
    last_output_offset: int = 0
    stream_finished: bool = False
    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming

    # Token accumulation (for non-streaming)
    # Output accumulation
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)

@@ -140,6 +139,8 @@ class GrpcRequestManager:
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.request_counter = 0
        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging

@@ -157,133 +158,22 @@ class GrpcRequestManager:
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
    ) -> asyncio.Queue:
        """
        Submit a generation request to the scheduler with n>1 parallel sampling support.

        This method implements the same two-phase approach as tokenizer_manager.py:
        1. Phase 1: Send prefix caching request (max_new_tokens=0)
        2. Phase 2: Send n generation requests that reuse the cached prefix

        Yields individual responses for streaming, or aggregated responses for non-streaming.
        Submit a generation request to the scheduler.
        Returns a queue for streaming outputs.
        """
        n = getattr(obj.sampling_params, "n", 1)

        if n <= 1:
            async for response in self._handle_single_request(
                obj, request_id, grpc_context
            ):
                yield response
            return

        # N>1 handling - two-phase approach
        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")

        # Generate base request ID if not provided
        if request_id is None:
            base_request_id = f"grpc-{uuid.uuid4().hex}"
        else:
            base_request_id = request_id

        # Phase 1: Cache the common prefix
        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
        prefix_obj = copy.copy(obj)
        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request

        # Send prefix caching request and consume response
        async for _ in self._handle_single_request(
            prefix_obj, f"{base_request_id}-prefix", grpc_context
        ):
            # Consume prefix response (usually just one chunk with finish_reason)
            pass

        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")

        # Phase 2: Generate n parallel requests
        logger.debug(f"Phase 2: Generating {n} parallel requests")
        generators = []
        request_ids = []

        for i in range(n):
            # Create individual generation request
            gen_obj = copy.copy(obj)
            gen_obj.sampling_params = copy.copy(obj.sampling_params)
            gen_obj.sampling_params.n = 1  # Each request generates 1 response
            gen_request_id = f"{base_request_id}-{i}"
            request_ids.append(gen_request_id)

            # Start generation request
            generators.append(
                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
            )

        # Handle response aggregation
        is_stream = getattr(obj, "stream", False)

        if not is_stream:
            # Non-streaming: collect all responses and return as batch
            logger.debug(f"Non-streaming mode: collecting {n} responses")
            responses = []
            for generator in generators:
                async for response in generator:
                    responses.append(response)
            yield responses  # Return all responses as a batch
        else:
            # Streaming mode: multiplex responses with index for ordering
            logger.debug(f"Streaming mode: multiplexing {n} streams")
            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}

            # Create async tasks for all generators
            task_map = {}
            for generator in generators:
                task = asyncio.create_task(generator.__anext__())
                task_map[task] = generator

            # Process responses as they arrive
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    generator = task_map.pop(task)
                    try:
                        response = await task

                        # Add index for client-side ordering
                        if isinstance(response, dict) and "meta_info" in response:
                            response_rid = response["meta_info"].get("id", "")
                            if response_rid in rid_to_index:
                                response["index"] = rid_to_index[response_rid]

                        yield response

                        # Create next task for this generator
                        next_task = asyncio.create_task(generator.__anext__())
                        task_map[next_task] = generator
                    except StopAsyncIteration:
                        # This generator is finished
                        pass

    async def _handle_single_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ):
        """Handle a single request - core implementation without n>1 logic."""
        # Generate request ID if not provided
        if request_id is None:
            request_id = f"grpc-{uuid.uuid4().hex}"
            async with self.request_counter_lock:
                request_id = f"grpc-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # Create and register request state
        # TODO: support log_request
        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,

@@ -299,51 +189,19 @@ class GrpcRequestManager:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        # Register state
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Send to scheduler via ZMQ
        try:
            # Send to scheduler - let exceptions bubble up to grpc_server.py
            await self._send_to_scheduler(obj)

            is_stream = getattr(obj, "stream", False)
            while True:
                # Client cancelled - notify scheduler and exit
                if grpc_context and grpc_context.cancelled():
                    await self.abort_request(request_id)
                    return

                try:
                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
                    if is_stream:
                        yield response

                    # Non-streaming: yield final response with accumulated tokens from state
                    if isinstance(response, dict) and response.get("finished", False):
                        if not is_stream:
                            final_response = response.copy()
                            final_response["token_ids"] = state.output_ids
                            yield final_response
                        break
                except asyncio.TimeoutError:
                    # Timeout waiting for response - abort and cleanup
                    logger.warning(
                        f"Timeout waiting for response for request {request_id}"
                    )
                    await self.abort_request(request_id)
                    return
        finally:
            # Always clean up request state when exiting
            self._cleanup_request_state(request_id)

    def _cleanup_request_state(self, request_id: str):
        """Clean up local request state (does not notify scheduler)."""
        if request_id in self.rid_to_state:
        except Exception as e:
            # Clean up on failure
            del self.rid_to_state[request_id]
            raise RuntimeError(f"Failed to send request to scheduler: {e}")

        return state.out_queue

    async def embedding_request(
        self,

@@ -356,7 +214,9 @@ class GrpcRequestManager:
        """
        # Generate request ID if not provided
        if request_id is None:
            request_id = f"grpc-embed-{uuid.uuid4().hex}"
            async with self.request_counter_lock:
                request_id = f"grpc-embed-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

@@ -495,6 +355,7 @@ class GrpcRequestManager:
            # Extract output for this request
            output_data = {
                "request_id": rid,
                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {

@@ -506,9 +367,6 @@ class GrpcRequestManager:
                        if batch_out.completion_tokens else 0
                    ),
                    "cached_tokens": (
                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]

@@ -517,110 +375,29 @@ class GrpcRequestManager:
                },
            }

            # Accumulate input logprobs (only once, usually in first chunk)
            if batch_out.input_token_logprobs_val and i < len(
                batch_out.input_token_logprobs_val
            ):
                if not state.input_token_logprobs_val:
                    state.input_token_logprobs_val.extend(
                        batch_out.input_token_logprobs_val[i]
                    )
                if batch_out.input_token_logprobs_idx and i < len(
                    batch_out.input_token_logprobs_idx
                ):
                    state.input_token_logprobs_idx.extend(
                        batch_out.input_token_logprobs_idx[i]
                    )
                if batch_out.input_top_logprobs_val and i < len(
                    batch_out.input_top_logprobs_val
                ):
                    state.input_top_logprobs_val.extend(
                        batch_out.input_top_logprobs_val[i]
                    )
                if batch_out.input_top_logprobs_idx and i < len(
                    batch_out.input_top_logprobs_idx
                ):
                    state.input_top_logprobs_idx.extend(
                        batch_out.input_top_logprobs_idx[i]
                    )

            # Send input logprobs based on mode
            if state.input_token_logprobs_val:
                if state.obj.stream and not state.input_logprobs_sent:
                    # Streaming: send input logprobs once in first chunk that has them
                    output_data["input_logprobs"] = {
                        "token_logprobs_val": state.input_token_logprobs_val,
                        "token_logprobs_idx": state.input_token_logprobs_idx,
                        "top_logprobs_val": state.input_top_logprobs_val,
                        "top_logprobs_idx": state.input_top_logprobs_idx,
                    }
                    state.input_logprobs_sent = True
                elif not state.obj.stream and output_data["finished"]:
                    # Non-streaming: send input logprobs in final chunk
                    output_data["input_logprobs"] = {
                        "token_logprobs_val": state.input_token_logprobs_val,
                        "token_logprobs_idx": state.input_token_logprobs_idx,
                        "top_logprobs_val": state.input_top_logprobs_val,
                        "top_logprobs_idx": state.input_top_logprobs_idx,
                    }

            # Add output logprobs if available (RAW - no detokenization!)
            # Add logprobs if available
            if batch_out.output_token_logprobs_val and i < len(
                batch_out.output_token_logprobs_val
            ):
                # Accumulate in state first
                state.output_token_logprobs_val.extend(
                    batch_out.output_token_logprobs_val[i]
                )
                if batch_out.output_token_logprobs_idx and i < len(
                    batch_out.output_token_logprobs_idx
                ):
                    state.output_token_logprobs_idx.extend(
                        batch_out.output_token_logprobs_idx[i]
                    )
                if batch_out.output_top_logprobs_val and i < len(
                    batch_out.output_top_logprobs_val
                ):
                    state.output_top_logprobs_val.extend(
                output_data["logprobs"] = {
                    "tokens": batch_out.output_token_logprobs_val[i],
                    "top_logprobs": (
                        batch_out.output_top_logprobs_val[i]
                    )
                if batch_out.output_top_logprobs_idx and i < len(
                    batch_out.output_top_logprobs_idx
                ):
                    state.output_top_logprobs_idx.extend(
                        batch_out.output_top_logprobs_idx[i]
                    )

                if state.obj.stream:
                    # For streaming: send incremental logprobs (only new tokens in this chunk)
                    # NOTE: this is different than TokenizerManager, which always accumulates
                    def get_part(attr_name):
                        source_list = getattr(batch_out, attr_name, None)
                        return (
                            source_list[i]
                            if source_list and i < len(source_list)
                            else []
                        )

                    output_data["output_logprobs"] = {
                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
                        "token_logprobs_idx": get_part("output_token_logprobs_idx"),
                        "top_logprobs_val": get_part("output_top_logprobs_val"),
                        "top_logprobs_idx": get_part("output_top_logprobs_idx"),
                    }
                elif output_data["finished"]:
                    # Non-streaming: send cumulative output logprobs in final chunk
                    output_data["output_logprobs"] = {
                        "token_logprobs_val": state.output_token_logprobs_val,
                        "token_logprobs_idx": state.output_token_logprobs_idx,
                        "top_logprobs_val": state.output_top_logprobs_val,
                        "top_logprobs_idx": state.output_top_logprobs_idx,
                    }

            # Update state for accumulation
                        if batch_out.output_top_logprobs_val
                        and i < len(batch_out.output_top_logprobs_val)
                        else None
                    ),
                }

            # Update state
            if output_data["text"]:
                state.text += output_data["text"][state.last_output_offset :]
                state.last_output_offset = len(output_data["text"])
            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Send to output queue
            await state.out_queue.put(output_data)

            # Handle completion
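Note: after this change, generate_request hands back an asyncio.Queue instead of an async generator. A minimal consumer sketch (names follow this diff; the demo producer is hypothetical) showing how a caller drains that queue until a chunk is marked finished:

    import asyncio


    async def consume_outputs(output_queue: asyncio.Queue) -> list:
        """Drain streamed chunks from the queue until the scheduler marks completion."""
        chunks = []
        while True:
            try:
                output = await asyncio.wait_for(output_queue.get(), timeout=4)
            except asyncio.TimeoutError:
                continue  # the real server checks for client cancellation here
            chunks.append(output)
            if output.get("finished", False):
                return chunks


    async def _demo():
        q: asyncio.Queue = asyncio.Queue()
        # Simulate the request manager pushing two chunks.
        await q.put({"text": "Hel", "finished": False})
        await q.put({"text": "Hello", "finished": True})
        print(await consume_outputs(q))


    if __name__ == "__main__":
        asyncio.run(_demo())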
python/sglang/srt/entrypoints/grpc_server.py

@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
            # Convert gRPC request to internal format
            tokenized_req = self._convert_generate_request(request)

            # Submit to request manager (automatically handles n>1)
            response_generator = self.request_manager.generate_request(
            # Submit to request manager
            output_queue = await self.request_manager.generate_request(
                obj=tokenized_req,
                request_id=request.request_id,
                grpc_context=context,
            )

            async for output in response_generator:
                # Handle batch responses (for n>1 non-streaming)
                if isinstance(output, list):
                    for batch_output in output:
                        if "error" in batch_output:
                            yield sglang_scheduler_pb2.GenerateResponse(
                                request_id=request.request_id,
                                error=sglang_scheduler_pb2.GenerateError(
                                    message=batch_output["error"],
                                    http_status_code=(
                                        "500" if "abort" not in batch_output else "499"
                                    ),
                                ),
                            )
                        else:
                            # All non-error batch outputs are final responses
                            yield self._create_completion_response(
                                request.request_id, batch_output
                            )
                else:
                    # Handle single response (for streaming or n=1 non-streaming)
            # Stream outputs
            while True:
                try:
                    # Get output with timeout
                    output = await asyncio.wait_for(output_queue.get(), timeout=4)

                    # Check for errors
                    if "error" in output:
                        yield sglang_scheduler_pb2.GenerateResponse(
                            request_id=request.request_id,

@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                            ),
                        )
                    elif output.get("finished", False):
                        break

                    # Check if finished
                    if output.get("finished", False):
                        # Send completion
                        yield self._create_completion_response(
                            request.request_id, output
                        )
                        break
                    else:
                        # Send chunk
                        yield self._create_chunk_response(request.request_id, output)
                except asyncio.TimeoutError:
                    # Check if context is still active
                    if context.cancelled():
                        # Abort the request
                        await self.request_manager.abort_request(request.request_id)
                        break
                    continue

        except Exception as e:
            logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
            yield sglang_scheduler_pb2.GenerateResponse(

@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                    prompt_tokens=result.get("prompt_tokens", 0),
                    cached_tokens=0,
                    embedding_dim=len(result["embedding"]),
                    generation_time=time.time() - self.start_time,
                ),
            )

@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
        logger.info(f"Sending health check request to request manager...")

        # Submit and wait for response
        output_generator = self.request_manager.generate_request(
        output_queue = await self.request_manager.generate_request(
            health_request, request_id=rid
        )

        try:
            # Get first response with timeout
            # Wait for response with configurable timeout
            response = await asyncio.wait_for(
                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
            )

            # Clean up

@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
            return_logprob=grpc_req.return_logprob,
            logprob_start_len=grpc_req.logprob_start_len or -1,
            top_logprobs_num=grpc_req.top_logprobs_num or 0,
            stream=grpc_req.stream or False,
            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
            stream=True,  # Always stream for gRPC
            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
            token_ids_logprob=(
                list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
            ),

@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
        regex = None
        json_schema = None
        ebnf_grammar = None
        structural_tag = None

        if grpc_params.HasField("regex"):
            regex = grpc_params.regex

@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
            json_schema = grpc_params.json_schema
        elif grpc_params.HasField("ebnf_grammar"):
            ebnf_grammar = grpc_params.ebnf_grammar
        elif grpc_params.HasField("structural_tag"):
            structural_tag = grpc_params.structural_tag

        return SGLSamplingParams(
            temperature=grpc_params.temperature or 1.0,

@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
            repetition_penalty=grpc_params.repetition_penalty or 1.0,
            max_new_tokens=grpc_params.max_new_tokens or 128,
            min_new_tokens=grpc_params.min_new_tokens or 0,
            stop=list(grpc_params.stop) if grpc_params.stop else [],
            stop=list(grpc_params.stop) if grpc_params.stop else None,
            stop_token_ids=(
                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
            ),
            skip_special_tokens=grpc_params.skip_special_tokens,
            spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
            regex=regex,
            json_schema=json_schema,
            ebnf=ebnf_grammar,
            structural_tag=structural_tag,
            n=grpc_params.n or 1,
            ignore_eos=grpc_params.ignore_eos,
        )

    def _convert_logprobs_to_proto(
        self, logprobs_data: Dict
    ) -> Optional[sglang_scheduler_pb2.LogProbs]:
        """Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
        if not logprobs_data:
            return None

        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])

        # Build TopLogProbs entries
        top_logprobs_proto = []
        if top_logprobs_val and top_logprobs_idx:
            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
                top_logprobs_proto.append(
                    sglang_scheduler_pb2.TopLogProbs(
                        values=val_list,
                        token_ids=idx_list,
                    )
                )

        return sglang_scheduler_pb2.LogProbs(
            token_logprobs=token_logprobs_val,
            token_ids=token_logprobs_idx,
            top_logprobs=top_logprobs_proto,
        )

    def _create_chunk_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a streaming chunk response."""
        meta_info = output.get("meta_info", {})

        # Convert output logprobs if present
        output_logprobs_proto = self._convert_logprobs_to_proto(
            output.get("output_logprobs")
        )

        # Convert input logprobs if present (only in first chunk)
        input_logprobs_proto = self._convert_logprobs_to_proto(
            output.get("input_logprobs")
        )

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            chunk=sglang_scheduler_pb2.GenerateStreamChunk(
                token_ids=output.get("token_ids", []),
                prompt_tokens=meta_info.get("prompt_tokens", 0),
                completion_tokens=meta_info.get("completion_tokens", 0),
                cached_tokens=meta_info.get("cached_tokens", 0),
                output_logprobs=output_logprobs_proto,
                input_logprobs=input_logprobs_proto,
                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
                text=output.get("text", ""),
                prompt_tokens=0,
                completion_tokens=len(output.get("token_ids", [])),
                cached_tokens=0,
                generation_time=time.time() - self.start_time,
                queue_time=0.0,
            ),
        )

@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a completion response."""
        # Extract meta info and finish reason details
        # Determine finish reason
        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
        meta_info = output.get("meta_info", {})
        finish_reason_data = meta_info.get("finish_reason")

        # Determine finish reason, default is stop
        finish_reason = "stop"
        if finish_reason_data:
            if isinstance(finish_reason_data, dict):
                finish_reason_type = finish_reason_data.get("type")
            else:
                # Handle legacy string format
                finish_reason_type = finish_reason_data
            if finish_reason_type == "length":
                finish_reason = "length"
            elif finish_reason_type == "abort":
                finish_reason = "abort"

        # Extract matched_stop information
        matched_stop_kwargs = {}
        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
            matched = finish_reason_data["matched"]
            if isinstance(matched, int):
                matched_stop_kwargs["matched_token_id"] = matched
            elif isinstance(matched, str):
                matched_stop_kwargs["matched_stop_str"] = matched

        # Convert output logprobs if present
        output_logprobs_proto = self._convert_logprobs_to_proto(
            output.get("output_logprobs")
        )

        # Convert input logprobs if present
        input_logprobs_proto = self._convert_logprobs_to_proto(
            output.get("input_logprobs")
        )
        if meta_info.get("finish_reason") == "length":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
        elif meta_info.get("finish_reason") == "eos_token":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            complete=sglang_scheduler_pb2.GenerateComplete(
                output_ids=output.get("token_ids", []),
                output_text=output.get("text", ""),
                finish_reason=finish_reason,
                prompt_tokens=meta_info.get("prompt_tokens", 0),
                completion_tokens=meta_info.get(
                    "completion_tokens", len(output.get("token_ids", []))
                ),
                cached_tokens=meta_info.get("cached_tokens", 0),
                output_logprobs=output_logprobs_proto,
                input_logprobs=input_logprobs_proto,
                **matched_stop_kwargs,
            ),
        )
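For context, a sketch of the finish-reason mapping that _create_completion_response now performs. GenerateComplete below is a stand-in IntEnum for the generated protobuf enum (the real values come from sglang_scheduler_pb2), so treat this only as an illustration of the control flow:

    from enum import IntEnum


    class GenerateComplete(IntEnum):  # stand-in for the proto enum
        STOP = 0
        LENGTH = 1
        EOS_TOKEN = 2


    def map_finish_reason(meta_info: dict) -> GenerateComplete:
        reason = meta_info.get("finish_reason")
        if reason == "length":
            return GenerateComplete.LENGTH
        if reason == "eos_token":
            return GenerateComplete.EOS_TOKEN
        return GenerateComplete.STOP  # default


    print(map_finish_reason({"finish_reason": "length"}))  # GenerateComplete.LENGTH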
python/sglang/srt/entrypoints/openai/protocol.py

@@ -16,7 +16,7 @@
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union

from openai.types.responses import (
    ResponseFunctionToolCall,

@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None
    # Cache salt for request caching
    cache_salt: Optional[Union[List[str], str]] = None
    # Priority for the request
    priority: Optional[int] = None

    # For custom metric labels
    custom_labels: Optional[Dict[str, str]] = None
    # For customer metric labels
    customer_labels: Optional[Dict[str, str]] = None

    @field_validator("max_tokens")
    @classmethod

@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
    """Function response."""

    name: Optional[str] = None
    arguments: Optional[str | Dict[str, Any]] = None
    arguments: Optional[str] = None


class ToolCall(BaseModel):

@@ -392,7 +388,7 @@ class Function(BaseModel):
    """Function descriptions."""

    description: Optional[str] = Field(default=None, examples=[None])
    name: str
    name: Optional[str] = None
    parameters: Optional[object] = None
    strict: bool = False

@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
    extra_key: Optional[Union[List[str], str]] = None
    # Cache salt for request caching
    cache_salt: Optional[Union[List[str], str]] = None
    # Priority for the request
    priority: Optional[int] = None

@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
        description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
    )
    priority: int = Field(default=0, description="Request priority")
    extra_key: Optional[str] = Field(
        default=None,
        description="Extra key for classifying the request (e.g. cache_salt)",
    )
    cache_salt: Optional[str] = Field(
        default=None, description="Cache salt for request caching"
    )

    # SGLang-specific sampling parameters
    frequency_penalty: float = 0.0

@@ -943,16 +928,6 @@ class MessageProcessingResult:
    tool_call_constraint: Optional[Any] = None


class ToolCallProcessingResult(NamedTuple):
    """Result of processing tool calls in a response."""

    tool_calls: Optional[List[Any]]  # List of ToolCall objects or None if parsing failed
    remaining_text: str  # Text remaining after parsing tool calls
    finish_reason: Dict[str, Any]  # Updated finish reason dictionary


class ResponseReasoningTextContent(BaseModel):
    text: str
    type: Literal["reasoning_text"] = "reasoning_text"
python/sglang/srt/entrypoints/openai/serving_base.py

@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
        self.tokenizer_manager = tokenizer_manager
        self.allowed_custom_labels = (
            set(
                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
            )
            if isinstance(self.tokenizer_manager.server_args, ServerArgs)
            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
            else None
        )

@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
            return self.create_error_response(
                message=e.detail, err_type=str(e.status_code), status_code=e.status_code
            )
        except ValueError as e:
            return self.create_error_response(
                message=str(e),
                err_type="BadRequest",
                status_code=400,
            )
        except Exception as e:
            logger.exception(f"Error in request: {e}")
            return self.create_error_response(

@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
        return f"{self._request_id_prefix()}{uuid.uuid4().hex}"

    def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
        """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
        parts = []
        for key in ["cache_salt", "extra_key"]:
            value = getattr(request, key, None)
            if value:
                if not isinstance(value, str):
                    raise TypeError(
                        f"Value of {key} must be a string, but got {type(value).__name__}"
                    )
                parts.append(value)
        return "".join(parts) if parts else None

    @abstractmethod
    def _convert_to_internal_request(
        self,

@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
        )
        return json.dumps({"error": error.model_dump()})

    def extract_custom_labels(self, raw_request):
    def extract_customer_labels(self, raw_request):
        if (
            not self.allowed_custom_labels
            or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
        ):
            return None

        custom_labels = None
        customer_labels = None
        header = (
            self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
        )

@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
            raw_labels = None

        if isinstance(raw_labels, dict):
            custom_labels = {
            customer_labels = {
                label: value
                for label, value in raw_labels.items()
                if label in self.allowed_custom_labels
            }
        return custom_labels
        return customer_labels
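A minimal sketch (header name and label keys are hypothetical) of the filtering that extract_customer_labels performs: parse the metrics-label header as JSON and keep only labels the server explicitly allows:

    import json


    def filter_labels(header_value: str, allowed: set) -> dict | None:
        try:
            raw_labels = json.loads(header_value)
        except json.JSONDecodeError:
            return None
        if not isinstance(raw_labels, dict):
            return None
        # Keep only allow-listed label names.
        return {k: v for k, v in raw_labels.items() if k in allowed}


    print(filter_labels('{"team": "search", "env": "prod"}', allowed={"team"}))
    # {'team': 'search'}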
python/sglang/srt/entrypoints/openai/serving_chat.py

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError

from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,

@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
    LogProbs,
    MessageProcessingResult,
    ToolCall,
    ToolCallProcessingResult,
    ToolChoice,
    TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase

@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
    process_hidden_states_from_ret,
    to_openai_style_logprobs,
)
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format

@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
        super().__init__(tokenizer_manager)
        self.template_manager = template_manager
        self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
        self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser

    def _request_id_prefix(self) -> str:
        return "chatcmpl-"

@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
        ):
            return "Tools cannot be empty if tool choice is set to required."

        if request.tool_choice is not None and not isinstance(request.tool_choice, str):
            if not request.tools:
                return "Tools cannot be empty if tool choice is set to a specific tool."
            tool_name = request.tool_choice.function.name
            tool_exists = any(tool.function.name == tool_name for tool in request.tools)
            if not tool_exists:
                return f"Tool '{tool_name}' not found in tools list."

        # Validate tool definitions
        for i, tool in enumerate(request.tools or []):
            if tool.function.parameters is None:
                continue
            try:
                Draft202012Validator.check_schema(tool.function.parameters)
            except SchemaError as e:
                return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"

        max_output_tokens = request.max_completion_tokens or request.max_tokens
        server_context_length = self.tokenizer_manager.server_args.context_length
        if (

@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
        else:
            prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

        # Extract custom labels from raw request headers
        custom_labels = self.extract_custom_labels(raw_request)
        # Extract customer labels from raw request headers
        customer_labels = self.extract_customer_labels(raw_request)

        adapted_request = GenerateReqInput(
            **prompt_kwargs,

@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            customer_labels=customer_labels,
        )

        return adapted_request, request

@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

            # Handle JSON schema constraint directly for required or named tool choice
            if request.tool_choice == "required" or isinstance(
                request.tool_choice, ToolChoice
            ):
                json_schema = get_json_schema_constraint(
                    request.tools, request.tool_choice
                )
                tool_call_constraint = ("json_schema", json_schema)

        # Use chat template
        if self.template_manager.chat_template_name is None:

@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
            sampling_params[constraint_type] = convert_json_schema_to_str(
                constraint_value.model_dump(by_alias=True)
            )
        elif constraint_type == "json_schema":
            sampling_params[constraint_type] = convert_json_schema_to_str(
                constraint_value
            )
        else:
            sampling_params[constraint_type] = constraint_value
        return sampling_params

@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
                stream_buffers[index] = stream_buffer + delta

                # Handle reasoning content
                if self.reasoning_parser and request.separate_reasoning:
                if (
                    self.tokenizer_manager.server_args.reasoning_parser
                    and request.separate_reasoning
                ):
                    reasoning_text, delta = self._process_reasoning_stream(
                        index, delta, reasoning_parser_dict, content, request
                    )

@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
        # Handle reasoning content
        reasoning_text = None
        reasoning_parser = self.reasoning_parser
        reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
        if reasoning_parser and request.separate_reasoning:
            is_force_reasoning = (
                self.template_manager.force_reasoning

@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
            and request.tools
            and self.tool_call_parser
        ):
            history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
            tool_calls, text, finish_reason = self._process_tool_calls(
                text,
                request.tools,
                finish_reason,
                request.tool_choice,
                history_tool_calls_cnt,
                text, request.tools, finish_reason
            )

        choice_data = ChatCompletionResponseChoice(

@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
            token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
        return ChoiceLogprobs(content=token_logprobs)

    def _process_tool_call_id(
        self,
        call_item: ToolCallItem,
        history_tool_calls_cnt: int,
    ) -> str:
        """Process for generating a new and unique `tool_call_id`"""
        if self.tool_call_parser != "kimi_k2":
            # A simple uuid is sufficient for all models except for Kimi-K2.
            tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
            return tool_call_id
        else:
            # Align with Kimi-K2 format: functions.{name}:{index}
            # Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
            # Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
            tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}"
            logger.debug(
                f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
            )
            return tool_call_id

    def _process_tool_calls(
        self,
        text: str,
        tools: List[Any],
        finish_reason: Dict[str, Any],
        tool_choice: Optional[Union[str, ToolChoice]] = None,
        history_tool_calls_cnt: int = 0,
    ) -> ToolCallProcessingResult:
    ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
        """Process tool calls in the response"""

        # Handle required or named tool choice
        if tool_choice == "required" or (
            isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
        ):
            # Set finish reason to tool_calls since we're processing tool calls
            if finish_reason["type"] == "stop":
                finish_reason["type"] = "tool_calls"
                finish_reason["matched"] = None
            try:
                # For required tool choice, we expect a JSON array of tool calls
                tool_call_data = json.loads(text)
                tool_calls = []
                for i, tool in enumerate(tool_call_data):
                    # Create a ToolCallItem from the JSON data
                    call_info = ToolCallItem(
                        tool_index=i,  # Use the loop index as tool_index
                        name=tool["name"],
                        parameters=json.dumps(tool["parameters"], ensure_ascii=False),
                    )
                    tool_id = self._process_tool_call_id(call_info, history_tool_calls_cnt)
                    tool_calls.append(
                        ToolCall(
                            id=tool_id,
                            index=i,
                            function=FunctionResponse(
                                name=tool["name"],
                                arguments=json.dumps(tool["parameters"], ensure_ascii=False),
                            ),
                        )
                    )
                return ToolCallProcessingResult(tool_calls, "", finish_reason)
            except json.JSONDecodeError as e:
                logger.error(f"Tool call parsing error: {e}")
                return ToolCallProcessingResult(None, text, finish_reason)

        # Use parser since output is not constrained by JSON schema
        parser = FunctionCallParser(tools, self.tool_call_parser)
        if parser.has_tool_call(text):
            if finish_reason["type"] == "stop":

@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
            text, call_info_list = parser.parse_non_stream(text)
            tool_calls = []
            for call_info in call_info_list:
                tool_id = self._process_tool_call_id(call_info, history_tool_calls_cnt)
                # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
                if self.tool_call_parser == "kimi_k2" and call_info.name is not None:
                    tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
                else:
                    tool_id = f"call_{uuid.uuid4().hex[:24]}"

                tool_calls.append(
                    ToolCall(
                        id=tool_id,

@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
                        ),
                    )
                )
            return ToolCallProcessingResult(tool_calls, text, finish_reason)
            return tool_calls, text, finish_reason
        except Exception as e:
            logger.error(f"Tool call parsing error: {e}")
            # Return error but don't fail the whole request
            return ToolCallProcessingResult(None, text, finish_reason)
            return None, text, finish_reason

        return ToolCallProcessingResult(None, text, finish_reason)
        return None, text, finish_reason

    def _process_streaming_logprobs(
        self, content: Dict[str, Any], n_prev_token: int

@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
                or self._get_enable_thinking_from_request(request)
            )
            reasoning_parser_dict[index] = ReasoningParser(
                self.reasoning_parser,
                self.tokenizer_manager.server_args.reasoning_parser,
                request.stream_reasoning,
                is_force_reasoning,
            )
        reasoning_parser = reasoning_parser_dict[index]
        return reasoning_parser.parse_stream_chunk(delta)

    def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
        """Counts the number of tool calls in the request's message history.

        NOTE: This method is only useful for models that include self-increasing
        history tool call idx in tool calls id, such as kimi-k2

        Args:
            request: The chat completion request object.

        Returns:
            The total number of tool calls in the history, or 0 if not applicable.
        """
        messages = getattr(request, "messages", [])
        idx = 0
        for msg in messages:
            if msg.role == "assistant":
                tool_calls = getattr(msg, "tool_calls", None)
                idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
        return idx

    def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
        """Extracts the 'enable_thinking' flag from request chat_template_kwargs.

@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
        """
        if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
            # For Qwen3 models, `enable_thinking` is supported.
            if self.reasoning_parser in ["qwen3", "glm45"]:
                return request.chat_template_kwargs.get("enable_thinking", False)
            if request.chat_template_kwargs.get("enable_thinking") is not None:
                return request.chat_template_kwargs.get("enable_thinking")
            # For DeepSeek-V3.1 models, `thinking` is supported.
            elif self.reasoning_parser in ["deepseek-v3"]:
                return request.chat_template_kwargs.get("thinking", False)
            elif request.chat_template_kwargs.get("thinking") is not None:
                return request.chat_template_kwargs.get("thinking")
            else:
                return False
        return False

@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
        ):
            """Process tool calls in streaming response"""
            if index not in parser_dict:
                # Use JSON detector directly for required or named tool choice
                if request.tool_choice == "required" or isinstance(
                    request.tool_choice, ToolChoice
                ):
                    parser_dict[index] = JsonArrayParser()
                else:
                    parser_dict[index] = FunctionCallParser(
                        tools=request.tools,
                        tool_call_parser=self.tool_call_parser,
                    )
                parser_dict[index] = FunctionCallParser(
                    tools=request.tools,
                    tool_call_parser=self.tool_call_parser,
                )
            parser = parser_dict[index]

            # Handle both FunctionCallParser and JsonArrayParser
            if isinstance(parser, JsonArrayParser):
                result = parser.parse_streaming_increment(delta, request.tools)
                normal_text, calls = result.normal_text, result.calls
            else:
                normal_text, calls = parser.parse_stream_chunk(delta)
            normal_text, calls = parser.parse_stream_chunk(delta)

            # Yield normal text
            if normal_text:

@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
                yield f"data: {chunk.model_dump_json()}\n\n"

            # Yield tool calls
            history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
            for call_item in calls:
                # Mark that this choice has tool calls
                has_tool_calls[index] = True

@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
                # Tool call ID should be generated only once per tool call
                if call_item.name:
                    # First chunk: include ID and function name
                    tool_call_id = self._process_tool_call_id(call_item, history_tool_calls_cnt)
                    if self.tool_call_parser == "kimi_k2":
                        # Align with Kimi-K2 format: functions.{name}:{index}
                        tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
                    else:
                        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
                    function_name = call_item.name
                else:
                    # Subsequent chunks: null ID and name for argument deltas

@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
    def _check_for_unstreamed_tool_args(
        self,
        parser: Union[FunctionCallParser, JsonArrayParser],
        parser: FunctionCallParser,
        content: Dict[str, Any],
        request: ChatCompletionRequest,
        index: int,

@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
        when generation finishes. This ensures tool calls are properly completed
        even if the model generates the final arguments in the last chunk.
        """
        # Get the detector - either from FunctionCallParser or directly if json detector
        detector = parser.detector if hasattr(parser, "detector") else parser

        # Only check if we have tool calls and the detector has tracked data
        # Only check if we have tool calls and the parser has tracked data
        if (
            not hasattr(detector, "prev_tool_call_arr")
            or not detector.prev_tool_call_arr
            not hasattr(parser.detector, "prev_tool_call_arr")
            or not parser.detector.prev_tool_call_arr
        ):
            return None

        if (
            not hasattr(detector, "streamed_args_for_tool")
            or not detector.streamed_args_for_tool
            not hasattr(parser.detector, "streamed_args_for_tool")
            or not parser.detector.streamed_args_for_tool
        ):
            return None

        # Get the last tool call that was being processed
        tool_index = len(detector.prev_tool_call_arr) - 1
        if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
        tool_index = len(parser.detector.prev_tool_call_arr) - 1
        if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
            return None

        # Get expected vs actual arguments
        expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
        expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
            "arguments", {}
        )
        expected_call = json.dumps(expected_args, ensure_ascii=False)
        actual_call = detector.streamed_args_for_tool[tool_index]
        actual_call = parser.detector.streamed_args_for_tool[tool_index]

        # Check if there are remaining arguments to send
        remaining_call = (
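A small sketch of the two tool_call_id formats this diff falls back to: Kimi-K2 uses "functions.{name}:{index}", every other parser gets a random "call_" id (helper name and arguments are illustrative only):

    import uuid


    def make_tool_call_id(parser: str, name: str, tool_index: int) -> str:
        if parser == "kimi_k2" and name is not None:
            return f"functions.{name}:{tool_index}"
        return f"call_{uuid.uuid4().hex[:24]}"


    print(make_tool_call_id("kimi_k2", "get_weather", 0))  # functions.get_weather:0
    print(make_tool_call_id("llama3", "get_weather", 0))   # call_<24 hex chars>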
python/sglang/srt/entrypoints/openai/serving_completions.py

@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
        else:
            prompt_kwargs = {"input_ids": prompt}

        # Extract custom labels from raw request headers
        custom_labels = self.extract_custom_labels(raw_request)
        # Extract customer labels from raw request headers
        customer_labels = self.extract_customer_labels(raw_request)

        adapted_request = GenerateReqInput(
            **prompt_kwargs,

@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
            bootstrap_room=request.bootstrap_room,
            return_hidden_states=request.return_hidden_states,
            rid=request.rid,
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            customer_labels=customer_labels,
        )

        return adapted_request, request
python/sglang/srt/entrypoints/openai/serving_responses.py

@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
            sampling_params=sampling_params,
            stream=request.stream,
            rid=request.request_id,
            extra_key=self._compute_extra_key(request),
            background=request.background,
        )

@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
            sampling_params=sampling_params,
            stream=adapted_request.stream,
            rid=request_id,
            extra_key=adapted_request.extra_key,
            return_logprob=adapted_request.return_logprob,
            logprob_start_len=adapted_request.logprob_start_len,
            top_logprobs_num=adapted_request.top_logprobs_num,
python/sglang/srt/eplb/expert_location.py

@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
            logical_to_rank_dispatch_physical_map=(
                compute_logical_to_rank_dispatch_physical_map(
                    server_args=server_args,
                    logical_to_all_physical_map=logical_to_all_physical_map,
                    num_gpus=ep_size,
                    num_physical_experts=num_physical_experts,

@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
    server_args: ServerArgs,
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,

@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
):
    r = random.Random(seed)

    num_local_gpu_physical_experts = num_physical_experts // num_gpus
    num_gpus_per_node = server_args.ep_size // server_args.nnodes
    num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
                physical_expert_id
                for physical_expert_id in candidate_physical_expert_ids
                if _compute_gpu_id_of_physical_expert(
                    physical_expert_id, num_local_gpu_physical_experts
                    physical_expert_id, num_local_physical_experts
                )
                == gpu_id
            ]
            if len(same_gpu_physical_expert_ids) > 0:
                # 1. Prefer same-GPU experts
                output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
            else:
                # 2. Otherwise, prefer same-node experts
                node_id = gpu_id // num_gpus_per_node
                same_node_physical_expert_ids = [
                    physical_expert_id
                    for physical_expert_id in candidate_physical_expert_ids
                    if _compute_node_id_of_physical_expert(
                        physical_expert_id, num_local_node_physical_experts
                    )
                    == node_id
                ]
                if len(same_node_physical_expert_ids) > 0:
                    output_partial[gpu_id] = same_node_physical_expert_ids[0]

        # 3. Fill remaining slots with fair random choices
        num_remain = torch.sum(output_partial == -1).item()
        output_partial[output_partial == -1] = torch.tensor(
            _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),

@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
def _compute_gpu_id_of_physical_expert(
    physical_expert_id: int, num_local_gpu_physical_experts: int
) -> int:
    return physical_expert_id // num_local_gpu_physical_experts


def _compute_node_id_of_physical_expert(
    physical_expert_id: int, num_local_host_physical_experts: int
    physical_expert_id: int, num_local_physical_experts: int
) -> int:
    return physical_expert_id // num_local_host_physical_experts
    return physical_expert_id // num_local_physical_experts


def _fair_choices(arr: List, k: int, r: random.Random) -> List:
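Worked example of the GPU-id computation after this simplification: with num_local_physical_experts = num_physical_experts // num_gpus, a physical expert's GPU is just an integer divide (the concrete numbers below are illustrative):

    def compute_gpu_id(physical_expert_id: int, num_local_physical_experts: int) -> int:
        return physical_expert_id // num_local_physical_experts


    # With 64 physical experts over 8 GPUs, each GPU holds 8 experts,
    # so physical expert 19 lives on GPU 2.
    num_physical_experts, num_gpus = 64, 8
    num_local_physical_experts = num_physical_experts // num_gpus
    print(compute_gpu_id(19, num_local_physical_experts))  # 2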
python/sglang/srt/function_call/function_call_parser.py

@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
from sglang.srt.function_call.utils import get_json_schema_constraint

logger = logging.getLogger(__name__)

@@ -179,8 +178,8 @@ class FunctionCallParser:
            strict_tag = self.get_structure_tag()
            return ("structural_tag", strict_tag)
        elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
            json_schema = get_json_schema_constraint(self.tools, tool_choice)
            return ("json_schema", json_schema)
            ebnf = self.get_ebnf(tool_choice)
            return ("ebnf", ebnf) if ebnf is not None else None

    def get_ebnf(self, tool_choice: Union[ToolChoice, Literal["required"]])
python/sglang/srt/function_call/glm4_moe_detector.py

@@ -39,7 +39,7 @@ def parse_arguments(json_value):
class Glm4MoeDetector(BaseFormatDetector):
    """
    Detector for GLM-4.5 and GLM-4.6 models.
    Detector for GLM-4.5 models.
    Assumes function call format:
    <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
    """

@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
        self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
        """Check if the text contains a glm-4.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:

@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
        Streaming incremental parsing tool calls for GLM-4.5 format.
        """
        self._buffer += new_text
        current_text = self._buffer
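A quick check of the <arg_key>/<arg_value> regex kept in this detector, applied to the tool-call format quoted in the docstring above (the sample string is taken from that docstring):

    import re

    func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
    sample = "<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n</tool_call>"
    print(re.findall(func_arg_regex, sample))  # [('city', '北京')]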
python/sglang/srt/function_call/json_array_parser.py  (deleted, 100644 → 0)

import json
import re
from typing import List

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult


class JsonArrayParser(BaseFormatDetector):
    """
    Parser for JSON array tool calls when JSON schema constraints are active.

    This parser is used when tool_choice="required" or a specific tool is named,
    bypassing model-specific parsers in favor of direct JSON array parsing.
    """

    def __init__(self):
        super().__init__()
        # Configure for JSON array parsing
        self.bot_token = "["
        self.eot_token = "]"
        self.tool_call_separator = ","

    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains a JSON tool call (array or single object).
        """
        return "[" in text or "{" in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse JSON tool calls using the base class implementation.
        """
        raise NotImplementedError(
            "Detect and parse not supported for JSON schema constraints."
        )

    def build_ebnf(self, tools: List[Tool]) -> str:
        """
        Build an EBNF grammar for constrained generation.

        This is not used for JSON schema constraints as they are handled
        by the constraint backends directly.
        """
        raise NotImplementedError(
            "EBNF generation is not supported for JSON schema constraints."
        )

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.
        """
        return super().parse_streaming_increment(new_text, tools)

    def structure_info(self) -> callable:
        """
        Return a function that creates StructureInfo for constrained generation.

        This is not used for JSON schema constraints as they are handled
        by the constraint backends directly.
        """
        raise NotImplementedError("structure_info not used for JSON schema constraints")
python/sglang/srt/function_call/utils.py
View file @
852a49c5
import
json
from
json
import
JSONDecodeError
,
JSONDecoder
from
json.decoder
import
WHITESPACE
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Tuple
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
sglang.srt.entrypoints.openai.protocol
import
Tool
,
ToolChoice
def
_find_common_prefix
(
s1
:
str
,
s2
:
str
)
->
str
:
prefix
=
""
...
...
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
except
(
JSONDecodeError
,
IndexError
)
as
e
:
msg
=
getattr
(
e
,
"msg"
,
str
(
e
))
if
"Extra data"
in
msg
or
"pop from empty list"
in
msg
:
start
=
WHITESPACE
.
match
(
input_str
,
0
).
end
()
obj
,
end
=
JSONDecoder
().
raw_decode
(
input_str
,
start
)
return
obj
,
end
except
JSONDecodeError
as
e
:
if
"Extra data"
in
e
.
msg
:
dec
=
JSONDecoder
()
return
dec
.
raw_decode
(
input_str
)
raise
...
...
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
        return True
    except JSONDecodeError:
        return False


def _get_tool_schema_defs(tools: List[Tool]) -> dict:
    """
    Get consolidated $defs from all tools, validating for conflicts.

    Args:
        tools: List of tools to process

    Returns:
        Dictionary of consolidated $defs from all tools

    Raises:
        ValueError: If conflicting $defs are found
    """
    all_defs = {}
    for tool in tools:
        if tool.function.parameters is None:
            continue
        defs = tool.function.parameters.get("$defs", {})
        for def_name, def_schema in defs.items():
            if def_name in all_defs and all_defs[def_name] != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has "
                    "multiple schemas, which is not "
                    "supported."
                )
            else:
                all_defs[def_name] = def_schema
    return all_defs


def _get_tool_schema(tool: Tool) -> dict:
    return {
        "properties": {
            "name": {"type": "string", "enum": [tool.function.name]},
            "parameters": (
                tool.function.parameters
                if tool.function.parameters
                else {"type": "object", "properties": {}}
            ),
        },
        "required": ["name", "parameters"],
    }


def get_json_schema_constraint(
    tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[dict]:
    """
    Get the JSON schema constraint for the specified tool choice.

    Args:
        tool_choice: The tool choice specification

    Returns:
        JSON schema dict, or None if no valid tools found
    """
    if isinstance(tool_choice, ToolChoice):
        # For specific function choice, return the user's parameters schema directly
        fn_name = tool_choice.function.name
        for tool in tools:
            if tool.function.name == fn_name:
                return {
                    "type": "array",
                    "minItems": 1,
                    "maxItems": 1,
                    "items": _get_tool_schema(tool),
                }
        return None
    elif tool_choice == "required":
        json_schema = {
            "type": "array",
            "minItems": 1,
            "items": {
                "type": "object",
                "anyOf": [_get_tool_schema(tool) for tool in tools],
            },
        }
        json_schema_defs = _get_tool_schema_defs(tools)
        if json_schema_defs:
            json_schema["$defs"] = json_schema_defs
        return json_schema
    return None
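For context, the `get_json_schema_constraint` helper removed in this hunk produced an array-typed JSON schema that the grammar/constraint backends could enforce. An illustrative sketch of the shape it returned for tool_choice="required" with a single hypothetical "get_weather" tool (this helper no longer exists after the commit):

# Illustrative only: shape of the schema the removed helper returned for
# tool_choice="required" and one hypothetical tool named "get_weather".
required_constraint = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_weather"]},
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
                "required": ["name", "parameters"],
            }
        ],
    },
}
# A conforming model output would look like:
#   [{"name": "get_weather", "parameters": {"city": "Paris"}}]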
python/sglang/srt/grpc/sglang_scheduler.proto View file @ 852a49c5
...
...
@@ -36,9 +36,9 @@ message SamplingParams {
  float presence_penalty = 6;
  float repetition_penalty = 7;
  optional int32 max_new_tokens = 8;
  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated uint32 stop_token_ids = 10;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;
...
...
@@ -47,24 +47,24 @@ message SamplingParams {
  string regex = 13;
  string json_schema = 14;
  string ebnf_grammar = 15;
  string structural_tag = 16;
  }
  // LoRA adapter
  string lora_path = 17;
  string lora_path = 16;
  // Speculative decoding
  int32 n = 18;  // Number of samples
  int32 n = 17;  // Number of samples
  // Token healing
  bool token_healing = 19;
  bool token_healing = 18;
  // Additional parameters
  int32 min_new_tokens = 20;
  bool ignore_eos = 21;
  bool no_stop_trim = 22;
  int32 stream_interval = 23;
  map<string, float> logit_bias = 24;
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  map<string, float> logit_bias = 23;
  string structural_tag = 24;
  // Custom parameters for extensibility
  google.protobuf.Struct custom_params = 25;
...
...
@@ -98,7 +98,7 @@ message GenerateRequest {
  bool return_logprob = 5;
  int32 logprob_start_len = 6;
  int32 top_logprobs_num = 7;
  repeated uint32 token_ids_logprob = 8;
  repeated int32 token_ids_logprob = 8;
  bool return_hidden_states = 9;

  // For disaggregated serving
...
...
@@ -122,14 +122,11 @@ message GenerateRequest {
  // For load balancing
  int32 dp_balance_id = 17;

  // Whether client wants streaming response
  bool stream = 18;
}

message TokenizedInput {
  string original_text = 1;  // For reference
  repeated uint32 input_ids = 2;
  repeated int32 input_ids = 2;
}

message MultimodalInputs {
...
...
@@ -166,50 +163,51 @@ message GenerateResponse {
}

message GenerateStreamChunk {
  // Generated tokens (incremental chunk)
  repeated uint32 token_ids = 1;
  // Generated token
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts
  int32 prompt_tokens = 2;
  int32 completion_tokens = 3;
  int32 cached_tokens = 4;
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Output logprobs (if requested) - incremental for streaming
  LogProbs output_logprobs = 5;
  // Logprobs (if requested)
  LogProbs logprobs = 6;

  // Hidden states (if requested)
  repeated float hidden_states = 6;
  repeated float hidden_states = 7;

  // Input logprobs (if requested) - only in first chunk
  LogProbs input_logprobs = 7;
  // Metadata
  float generation_time = 8;  // Time to generate this token
  int32 queue_time = 9;  // Time spent in queue
}

message GenerateComplete {
  // Final output
  repeated uint32 output_ids = 1;
  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
  string finish_reason = 2;
  // Token usage counts
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // Output logprobs if requested (cumulative)
  LogProbs output_logprobs = 6;
  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested
  repeated HiddenStates all_hidden_states = 7;

  // Matched stop information (for stop sequences)
  oneof matched_stop {
    uint32 matched_token_id = 8;
    string matched_stop_str = 9;
  }

  // Input logprobs if requested (for prompt tokens)
  LogProbs input_logprobs = 10;
  repeated HiddenStates all_hidden_states = 12;
}

message GenerateError {
...
...
@@ -224,11 +222,15 @@ message LogProbs {
  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}

message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}

message HiddenStates {
...
...
@@ -283,9 +285,10 @@ message EmbedComplete {
  // Additional metadata
  int32 embedding_dim = 4;
  float generation_time = 5;

  // For batch embeddings
  repeated Embedding batch_embeddings = 5;
  repeated Embedding batch_embeddings = 6;
}

message Embedding {
...
...
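To make the message changes concrete, here is a hedged client-side sketch of consuming the streaming RPC with the regenerated stubs. It assumes the post-commit field layout shown above (per-token GenerateStreamChunk.token_id/text, enum GenerateComplete.finish_reason); the stub class name, RPC name, and address are assumptions, not taken from the proto excerpt:

# Hedged sketch: stream tokens from the scheduler using the generated stubs.
# Assumptions: a server-streaming RPC named Generate on a stub called
# SglangSchedulerStub, and a scheduler listening on localhost:30000.
import grpc

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc


def stream_generate(prompt_ids):
    channel = grpc.insecure_channel("localhost:30000")
    stub = pb2_grpc.SglangSchedulerStub(channel)  # assumed stub name
    request = pb2.GenerateRequest(
        request_id="demo-req-1",
        tokenized=pb2.TokenizedInput(original_text="", input_ids=prompt_ids),
        sampling_params=pb2.SamplingParams(temperature=0.7, max_new_tokens=64),
    )
    for response in stub.Generate(request):  # assumed RPC name
        if response.HasField("chunk"):
            # Per-token streaming fields from the post-commit GenerateStreamChunk.
            print(response.chunk.text, end="", flush=True)
        elif response.HasField("complete"):
            reason = pb2.GenerateComplete.FinishReason.Name(
                response.complete.finish_reason
            )
            print(f"\n[finish_reason={reason}]")
        elif response.HasField("error"):
            raise RuntimeError(response.error.message)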
python/sglang/srt/grpc/sglang_scheduler_pb2.py View file @ 852a49c5
This diff is collapsed.
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi View file @ 852a49c5
...
...
@@ -3,6 +3,7 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
...
...
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar",
"structural_tag",
"lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias",
"structural_tag",
"custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
...
...
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
...
...
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
...
...
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
regex: str
json_schema: str
ebnf_grammar: str
structural_tag: str
lora_path: str
n: int
token_healing: bool
...
...
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
...
...
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id"
, "stream"
)
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
...
...
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
...
...
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
lora_id: str
data_parallel_rank: int
dp_balance_id: int
stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
...
...
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: LogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
generation_time: float
queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
output_text: str
finish_reason: GenerateComplete.FinishReason
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int
matched_stop_str: str
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
...
...
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
__slots__ = ("token_logprobs", "token_ids", "top_logprobs"
, "token_texts"
)
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids")
__slots__ = ("values", "token_ids"
, "token_texts"
)
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
...
...
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim",
"generation_time",
"batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")
...
...
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py View file @ 852a49c5

# This file is auto-generated. Do not edit manually.
# Regenerate with: python compile_proto.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
...
...
python/sglang/srt/hf_transformers_utils.py View file @ 852a49c5
...
...
@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
    return config


def _load_deepseek_v32_model(
    model_path: str,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
):
    # first get the local path
    local_path = download_from_hf(model_path)
    # then load the config file in json
    config_file = os.path.join(local_path, "config.json")
    if not os.path.exists(config_file):
        raise RuntimeError(f"Can't find config file in {local_path}.")
    with open(config_file, "r") as f:
        config_json = json.load(f)

    config_json["architectures"] = ["DeepseekV3ForCausalLM"]
    config_json["model_type"] = "deepseek_v3"

    tmp_path = os.path.join(local_path, "_tmp_config_folder")
    os.makedirs(tmp_path, exist_ok=True)
    unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
    with open(unique_path, "w") as f:
        json.dump(config_json, f)
    return AutoConfig.from_pretrained(
        unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )


@lru_cache_frozenset(maxsize=32)
def get_config(
    model: str,
...
...
@@ -140,9 +171,17 @@ def get_config(
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        model = client.get_local_dir()
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    try:
        config = AutoConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    except ValueError as e:
        if not "deepseek_v32" in str(e):
            raise e
        config = _load_deepseek_v32_model(
            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
        )
    if (
        config.architectures is not None
        and config.architectures[0] == "Phi4MMForCausalLM"
...
...
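The get_config fallback above works by rewriting the checkpoint's config.json so that an AutoConfig which does not yet recognize "deepseek_v32" can still load it as a DeepSeek-V3 config. A hedged standalone sketch of the same idea (the function name and paths here are illustrative, not from the commit):

# Standalone sketch of the workaround used by _load_deepseek_v32_model:
# patch architectures/model_type, then let AutoConfig load the patched JSON file.
import json
import os

from transformers import AutoConfig


def load_v32_config_as_v3(checkpoint_dir: str):
    with open(os.path.join(checkpoint_dir, "config.json")) as f:
        cfg = json.load(f)
    cfg["architectures"] = ["DeepseekV3ForCausalLM"]
    cfg["model_type"] = "deepseek_v3"
    patched = os.path.join(checkpoint_dir, "_patched_config.json")
    with open(patched, "w") as f:
        json.dump(cfg, f)
    # The commit relies on AutoConfig accepting a path to a config JSON file.
    return AutoConfig.from_pretrained(patched)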
python/sglang/srt/layers/attention/aiter_backend.py View file @ 852a49c5
...
...
@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
            assert len(k.shape) == 3
            assert len(v.shape) == 3

        if (
            forward_batch.forward_mode.is_extend()
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
        ):
        if forward_batch.forward_mode.is_extend():
            if kv_indices.shape[0] == 0:
                o = flash_attn_varlen_func(
                    q,
...
...
python/sglang/srt/layers/attention/ascend_backend.py View file @ 852a49c5
...
...
@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import custom_ops
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
...
...
@@ -36,6 +37,8 @@ class ForwardMetadata:
    seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_list: Optional[List[int]] = None
    seq_lens_list_cumsum: Optional[List[int]] = None
    seq_lens: Optional[torch.Tensor] = None
    actual_seq_lengths_q: Optional[torch.Tensor] = None


class AscendAttnBackend(AttentionBackend):
...
...
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
        if self.use_mla:
            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
            self.q_head_dim = (
                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
            )
            self.native_attn = TorchNativeAttnBackend(model_runner)
        self.graph_metadata = {}
        self.max_context_len = model_runner.model_config.context_len
...
...
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
            self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
        if forward_batch.is_extend_in_batch:
            seq_lens_list_cumsum[-1] = (
                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
            ) * tp_size
        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

        self.graph_mode = False
...
...
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
        metadata.seq_lens = seq_lens
        metadata.actual_seq_lengths_q = torch.tensor(
            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
        )

        self.graph_metadata[bs] = metadata
        self.forward_metadata = metadata
...
...
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
            metadata.block_tables[:bs, max_seq_pages:].fill_(0)
            metadata.block_tables[bs:, :].fill_(0)
        metadata.seq_lens[:bs].copy_(seq_lens[:bs])

        self.forward_metadata = metadata
        self.graph_mode = True
...
...
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_sparse(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi_head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: torch.Tensor = None,
    ):
        is_prefill = forward_batch.forward_mode.is_extend()

        if save_kv_cache:
            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, k_rope
            )
        q_nope, q_pe = q, q_rope
        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        block_table = self.forward_metadata.block_tables
        if is_prefill:
            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
        else:
            if self.forward_metadata.actual_seq_lengths_q is None:
                actual_seq_qlen = (
                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
                )
            else:
                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
        if self.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = self.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int

        attn_out = torch.ops.custom.npu_sparse_flash_attention(
            query=q_nope,
            key=k_nope,
            value=k_nope,
            query_rope=q_pe,
            key_rope=k_pe,
            sparse_indices=topk_indices,
            scale_value=layer.scaling,
            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
            block_table=block_table,
            sparse_block_size=1,
            layout_query="TND",
            layout_kv="PA_BSND",
            sparse_mode=3,
        )
        return attn_out

    def forward_extend(
        self,
        q,
...
...
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi_head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ):
        if topk_indices is not None:
            return self.forward_sparse(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache,
                q_rope,
                k_rope,
                topk_indices,
            )
        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
...
...
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ):
        if is_mla_preprocess_enabled():
            # MLAPO does saving kv_cache
            save_kv_cache = False
        if topk_indices is not None:
            return self.forward_sparse(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache,
                q_rope,
                k_rope,
                topk_indices,
            )
        if self.graph_mode:
            return self.forward_decode_graph(
...
...
python/sglang/srt/layers/attention/attention_registry.py View file @ 852a49c5

import logging

logger = logging.getLogger(__name__)

ATTENTION_BACKENDS = {}
...
...
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
    return AscendAttnBackend(runner)


@register_attention_backend("nsa")
def create_nsa_backend(runner):
    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend

    return NativeSparseAttnBackend(runner)


@register_attention_backend("triton")
def create_triton_backend(runner):
    assert not runner.model_config.is_encoder_decoder, (
...
...
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
    return DualChunkFlashAttentionBackend(runner)


def attn_backend_wrapper(runner, full_attn_backend):
    """
    Wrapper for special models like hybrid GDN, so we don't
    need to change the code of the original attention backend.
    """
    assert not (
        runner.is_hybrid_gdn and runner.use_mla_backend
    ), "hybrid_gdn can only be used with non-MLA models."

    # wrap for hybrid GDN models
    if runner.is_hybrid_gdn:
        from sglang.srt.utils import is_blackwell, is_npu

        if is_blackwell():
            assert (
                runner.server_args.attention_backend == "triton"
            ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
        if is_npu():
            assert (
                runner.server_args.attention_backend == "ascend"
            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
            HybridLinearAttnBackend,
            MambaAttnBackend,
        )
@register_attention_backend("hybrid_linear_attn")
def create_hybrid_linear_attn_backend(runner):
    assert (
        runner.is_hybrid_gdn
    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."

    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
        HybridLinearAttnBackend,
        MambaAttnBackend,
    )
    from sglang.srt.utils import is_blackwell, is_npu

        linear_attn_backend = MambaAttnBackend(runner)
        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
        return HybridLinearAttnBackend(full_attn_backend, linear_attn_backend, full_attn_layers)
    if is_npu():
        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

        full_attn_backend = AscendAttnBackend(runner)
    elif is_blackwell():
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        full_attn_backend = TritonAttnBackend(runner)
    else:
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )
    return full_attn_backend

        full_attn_backend = FlashAttentionBackend(runner)

    linear_attn_backend = MambaAttnBackend(runner)
    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
    return HybridLinearAttnBackend(
        full_attn_backend, linear_attn_backend, full_attn_layers
    )
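For reference, the registry pattern used throughout this file keeps backend factories in ATTENTION_BACKENDS keyed by the --attention-backend name. A minimal hedged sketch of registering and resolving a backend, assuming ATTENTION_BACKENDS and register_attention_backend are importable from this module (the backend name below is hypothetical):

# Minimal sketch based on the registry shown above: register a factory under a
# hypothetical name and resolve it later by that name.
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    register_attention_backend,
)


@register_attention_backend("my_experimental_backend")  # hypothetical name
def create_my_experimental_backend(runner):
    # Reuse an existing implementation purely for illustration.
    from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

    return TritonAttnBackend(runner)


def build_attn_backend(name: str, runner):
    # Scheduler-side resolution: look up the factory by backend name.
    return ATTENTION_BACKENDS[name](runner)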