change / sglang · Commits · 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

parent 8f7453e3 · Changes: 159

Showing 20 changed files with 512 additions and 1028 deletions (+512 −1028)
  python/sglang/srt/entrypoints/grpc_request_manager.py        +41   -264
  python/sglang/srt/entrypoints/grpc_server.py                  +46   -125
  python/sglang/srt/entrypoints/openai/protocol.py              +5    -30
  python/sglang/srt/entrypoints/openai/serving_base.py          +6    -25
  python/sglang/srt/entrypoints/openai/serving_chat.py          +49   -177
  python/sglang/srt/entrypoints/openai/serving_completions.py   +3    -4
  python/sglang/srt/entrypoints/openai/serving_responses.py     +0    -2
  python/sglang/srt/eplb/expert_location.py                     +5    -30
  python/sglang/srt/function_call/function_call_parser.py       +2    -3
  python/sglang/srt/function_call/glm4_moe_detector.py          +3    -3
  python/sglang/srt/function_call/json_array_parser.py          +0    -63
  python/sglang/srt/function_call/utils.py                      +5    -96
  python/sglang/srt/grpc/sglang_scheduler.proto                 +51   -48
  python/sglang/srt/grpc/sglang_scheduler_pb2.py                +68   -69
  python/sglang/srt/grpc/sglang_scheduler_pb2.pyi               +50   -39
  python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py           +0    -3
  python/sglang/srt/hf_transformers_utils.py                    +42   -3
  python/sglang/srt/layers/attention/aiter_backend.py           +1    -5
  python/sglang/srt/layers/attention/ascend_backend.py          +99   -4
  python/sglang/srt/layers/attention/attention_registry.py      +36   -35
python/sglang/srt/entrypoints/grpc_request_manager.py  (view file @ 852a49c5)

@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
 """
 import asyncio
-import copy
 import dataclasses
 import logging
 import os
@@ -12,8 +11,7 @@ import signal
 import sys
 import threading
 import time
-import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
 import grpc
 import zmq
@@ -81,10 +79,11 @@ class GrpcReqState:
     last_completion_tokens: int = 1
 
     # Streaming state
+    last_output_offset: int = 0
     stream_finished: bool = False
-    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
 
-    # Token accumulation (for non-streaming)
+    # Output accumulation
+    text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
     input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -140,6 +139,8 @@ class GrpcRequestManager:
         self.is_pause_cond = asyncio.Condition()
 
         # Metrics
+        self.request_counter = 0
+        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()
 
         # Crash dump for debugging
@@ -157,133 +158,22 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
+    ) -> asyncio.Queue:
         """
-        Submit a generation request to the scheduler with n>1 parallel sampling support.
-
-        This method implements the same two-phase approach as tokenizer_manager.py:
-        1. Phase 1: Send prefix caching request (max_new_tokens=0)
-        2. Phase 2: Send n generation requests that reuse the cached prefix
-
-        Yields individual responses for streaming, or aggregated responses for non-streaming.
+        Submit a generation request to the scheduler.
+        Returns a queue for streaming outputs.
         """
-        n = getattr(obj.sampling_params, "n", 1)
-
-        if n <= 1:
-            async for response in self._handle_single_request(
-                obj, request_id, grpc_context
-            ):
-                yield response
-            return
-
-        # N>1 handling - two-phase approach
-        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
-
-        # Generate base request ID if not provided
-        if request_id is None:
-            base_request_id = f"grpc-{uuid.uuid4().hex}"
-        else:
-            base_request_id = request_id
-
-        # Phase 1: Cache the common prefix
-        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
-        prefix_obj = copy.copy(obj)
-        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
-        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
-        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request
-
-        # Send prefix caching request and consume response
-        async for _ in self._handle_single_request(
-            prefix_obj, f"{base_request_id}-prefix", grpc_context
-        ):
-            # Consume prefix response (usually just one chunk with finish_reason)
-            pass
-
-        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
-
-        # Phase 2: Generate n parallel requests
-        logger.debug(f"Phase 2: Generating {n} parallel requests")
-        generators = []
-        request_ids = []
-
-        for i in range(n):
-            # Create individual generation request
-            gen_obj = copy.copy(obj)
-            gen_obj.sampling_params = copy.copy(obj.sampling_params)
-            gen_obj.sampling_params.n = 1  # Each request generates 1 response
-            gen_request_id = f"{base_request_id}-{i}"
-            request_ids.append(gen_request_id)
-
-            # Start generation request
-            generators.append(
-                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
-            )
-
-        # Handle response aggregation
-        is_stream = getattr(obj, "stream", False)
-
-        if not is_stream:
-            # Non-streaming: collect all responses and return as batch
-            logger.debug(f"Non-streaming mode: collecting {n} responses")
-            responses = []
-            for generator in generators:
-                async for response in generator:
-                    responses.append(response)
-            yield responses  # Return all responses as a batch
-        else:
-            # Streaming mode: multiplex responses with index for ordering
-            logger.debug(f"Streaming mode: multiplexing {n} streams")
-            rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
-
-            # Create async tasks for all generators
-            task_map = {}
-            for generator in generators:
-                task = asyncio.create_task(generator.__anext__())
-                task_map[task] = generator
-
-            # Process responses as they arrive
-            while task_map:
-                done, _ = await asyncio.wait(
-                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
-                )
-                for task in done:
-                    generator = task_map.pop(task)
-                    try:
-                        response = await task
-                        # Add index for client-side ordering
-                        if isinstance(response, dict) and "meta_info" in response:
-                            response_rid = response["meta_info"].get("id", "")
-                            if response_rid in rid_to_index:
-                                response["index"] = rid_to_index[response_rid]
-                        yield response
-                        # Create next task for this generator
-                        next_task = asyncio.create_task(generator.__anext__())
-                        task_map[next_task] = generator
-                    except StopAsyncIteration:
-                        # This generator is finished
-                        pass
-
-    async def _handle_single_request(
-        self,
-        obj: TokenizedGenerateReqInput,
-        request_id: Optional[str] = None,
-        grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ):
-        """Handle a single request - core implementation without n>1 logic."""
         # Generate request ID if not provided
         if request_id is None:
-            request_id = f"grpc-{uuid.uuid4().hex}"
+            async with self.request_counter_lock:
+                request_id = f"grpc-{self.request_counter}"
+                self.request_counter += 1
         obj.rid = request_id
 
-        # Create and register request state
         # TODO: support log_request
+        # Create request state
         state = GrpcReqState(
             request_id=request_id,
             grpc_context=grpc_context,
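Note on the removed n>1 path above: it fans several per-sample async generators into one stream with asyncio.wait(..., return_when=FIRST_COMPLETED), re-arming each generator after it yields. A minimal standalone sketch of that fan-in pattern follows; fake_stream and multiplex are hypothetical stand-ins, not SGLang APIs.

    import asyncio
    from typing import AsyncGenerator, Dict

    async def fake_stream(tag: str, n_chunks: int) -> AsyncGenerator[Dict, None]:
        # Stand-in for a per-sample response stream (hypothetical).
        for k in range(n_chunks):
            await asyncio.sleep(0.01)
            yield {"meta_info": {"id": tag}, "chunk": k}

    async def multiplex(streams):
        # Map each pending "next item" task back to the generator it came from.
        task_map = {asyncio.ensure_future(s.__anext__()): s for s in streams}
        while task_map:
            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                gen = task_map.pop(task)
                try:
                    item = task.result()
                except StopAsyncIteration:
                    continue  # this stream is exhausted
                yield item
                # Re-arm the generator that just produced an item.
                task_map[asyncio.ensure_future(gen.__anext__())] = gen

    async def main():
        async for item in multiplex([fake_stream("a", 2), fake_stream("b", 3)]):
            print(item)

    asyncio.run(main())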
@@ -299,51 +189,19 @@ class GrpcRequestManager:
             state.session_id = obj.session_params.session_id
             state.is_session_request = True
 
-        # Register state
         self.rid_to_state[request_id] = state
         self.record_request_for_crash_dump(obj)
 
+        # Send to scheduler via ZMQ
         try:
-            # Send to scheduler - let exceptions bubble up to grpc_server.py
             await self._send_to_scheduler(obj)
-
-            is_stream = getattr(obj, "stream", False)
-
-            while True:
-                # Client cancelled - notify scheduler and exit
-                if grpc_context and grpc_context.cancelled():
-                    await self.abort_request(request_id)
-                    return
-
-                try:
-                    response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
-
-                    if is_stream:
-                        yield response
-
-                    # Non-streaming: yield final response with accumulated tokens from state
-                    if isinstance(response, dict) and response.get("finished", False):
-                        if not is_stream:
-                            final_response = response.copy()
-                            final_response["token_ids"] = state.output_ids
-                            yield final_response
-                        break
-                except asyncio.TimeoutError:
-                    # Timeout waiting for response - abort and cleanup
-                    logger.warning(
-                        f"Timeout waiting for response for request {request_id}"
-                    )
-                    await self.abort_request(request_id)
-                    return
-        finally:
-            # Always clean up request state when exiting
-            self._cleanup_request_state(request_id)
-
-    def _cleanup_request_state(self, request_id: str):
-        """Clean up local request state (does not notify scheduler)."""
-        if request_id in self.rid_to_state:
+        except Exception as e:
+            # Clean up on failure
             del self.rid_to_state[request_id]
+            raise RuntimeError(f"Failed to send request to scheduler: {e}")
+
+        return state.out_queue
 
     async def embedding_request(
         self,
@@ -356,7 +214,9 @@ class GrpcRequestManager:
         """
         # Generate request ID if not provided
         if request_id is None:
-            request_id = f"grpc-embed-{uuid.uuid4().hex}"
+            async with self.request_counter_lock:
+                request_id = f"grpc-embed-{self.request_counter}"
+                self.request_counter += 1
         obj.rid = request_id
@@ -495,6 +355,7 @@ class GrpcRequestManager:
             # Extract output for this request
             output_data = {
                 "request_id": rid,
+                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                 "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                 "finished": batch_out.finished_reasons[i] is not None,
                 "meta_info": {
@@ -506,9 +367,6 @@ class GrpcRequestManager:
                         if batch_out.completion_tokens
                         else 0
                     ),
-                    "cached_tokens": (
-                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
-                    ),
                     "finish_reason": (
                         str(batch_out.finished_reasons[i])
                         if batch_out.finished_reasons[i]
@@ -517,110 +375,29 @@ class GrpcRequestManager:
                 },
             }
 
-            # Accumulate input logprobs (only once, usually in first chunk)
-            if batch_out.input_token_logprobs_val and i < len(
-                batch_out.input_token_logprobs_val
-            ):
-                if not state.input_token_logprobs_val:
-                    state.input_token_logprobs_val.extend(
-                        batch_out.input_token_logprobs_val[i]
-                    )
-            if batch_out.input_token_logprobs_idx and i < len(
-                batch_out.input_token_logprobs_idx
-            ):
-                state.input_token_logprobs_idx.extend(
-                    batch_out.input_token_logprobs_idx[i]
-                )
-            if batch_out.input_top_logprobs_val and i < len(
-                batch_out.input_top_logprobs_val
-            ):
-                state.input_top_logprobs_val.extend(
-                    batch_out.input_top_logprobs_val[i]
-                )
-            if batch_out.input_top_logprobs_idx and i < len(
-                batch_out.input_top_logprobs_idx
-            ):
-                state.input_top_logprobs_idx.extend(
-                    batch_out.input_top_logprobs_idx[i]
-                )
-
-            # Send input logprobs based on mode
-            if state.input_token_logprobs_val:
-                if state.obj.stream and not state.input_logprobs_sent:
-                    # Streaming: send input logprobs once in first chunk that has them
-                    output_data["input_logprobs"] = {
-                        "token_logprobs_val": state.input_token_logprobs_val,
-                        "token_logprobs_idx": state.input_token_logprobs_idx,
-                        "top_logprobs_val": state.input_top_logprobs_val,
-                        "top_logprobs_idx": state.input_top_logprobs_idx,
-                    }
-                    state.input_logprobs_sent = True
-                elif not state.obj.stream and output_data["finished"]:
-                    # Non-streaming: send input logprobs in final chunk
-                    output_data["input_logprobs"] = {
-                        "token_logprobs_val": state.input_token_logprobs_val,
-                        "token_logprobs_idx": state.input_token_logprobs_idx,
-                        "top_logprobs_val": state.input_top_logprobs_val,
-                        "top_logprobs_idx": state.input_top_logprobs_idx,
-                    }
-
-            # Add output logprobs if available (RAW - no detokenization!)
+            # Add logprobs if available
             if batch_out.output_token_logprobs_val and i < len(
                 batch_out.output_token_logprobs_val
             ):
-                # Accumulate in state first
-                state.output_token_logprobs_val.extend(
-                    batch_out.output_token_logprobs_val[i]
-                )
-                if batch_out.output_token_logprobs_idx and i < len(
-                    batch_out.output_token_logprobs_idx
-                ):
-                    state.output_token_logprobs_idx.extend(
-                        batch_out.output_token_logprobs_idx[i]
-                    )
-                if batch_out.output_top_logprobs_val and i < len(
-                    batch_out.output_top_logprobs_val
-                ):
-                    state.output_top_logprobs_val.extend(
-                        batch_out.output_top_logprobs_val[i]
-                    )
-                if batch_out.output_top_logprobs_idx and i < len(
-                    batch_out.output_top_logprobs_idx
-                ):
-                    state.output_top_logprobs_idx.extend(
-                        batch_out.output_top_logprobs_idx[i]
-                    )
-
-                if state.obj.stream:
-                    # For streaming: send incremental logprobs (only new tokens in this chunk)
-                    # NOTE: this is different than TokenizerManager, which always accumulates
-                    def get_part(attr_name):
-                        source_list = getattr(batch_out, attr_name, None)
-                        return (
-                            source_list[i] if source_list and i < len(source_list) else []
-                        )
-
-                    output_data["output_logprobs"] = {
-                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
-                        "token_logprobs_idx": get_part("output_token_logprobs_idx"),
-                        "top_logprobs_val": get_part("output_top_logprobs_val"),
-                        "top_logprobs_idx": get_part("output_top_logprobs_idx"),
-                    }
-                elif output_data["finished"]:
-                    # Non-streaming: send cumulative output logprobs in final chunk
-                    output_data["output_logprobs"] = {
-                        "token_logprobs_val": state.output_token_logprobs_val,
-                        "token_logprobs_idx": state.output_token_logprobs_idx,
-                        "top_logprobs_val": state.output_top_logprobs_val,
-                        "top_logprobs_idx": state.output_top_logprobs_idx,
-                    }
-
-            # Update state for accumulation
+                output_data["logprobs"] = {
+                    "tokens": batch_out.output_token_logprobs_val[i],
+                    "top_logprobs": (
+                        batch_out.output_top_logprobs_val[i]
+                        if batch_out.output_top_logprobs_val
+                        and i < len(batch_out.output_top_logprobs_val)
+                        else None
+                    ),
+                }
+
+            # Update state
+            if output_data["text"]:
+                state.text += output_data["text"][state.last_output_offset :]
+                state.last_output_offset = len(output_data["text"])
             if output_data["token_ids"]:
                 state.output_ids.extend(output_data["token_ids"])
 
             # Send to output queue
             await state.out_queue.put(output_data)
 
             # Handle completion
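Note: after this change generate_request returns an asyncio.Queue instead of an async generator, so callers drain the queue themselves (grpc_server.py below does exactly this). A minimal consumption sketch under that assumption; request_manager and tokenized_req are placeholders for the real objects.

    import asyncio

    async def consume(request_manager, tokenized_req, request_id, grpc_context=None):
        # Queue-based API: generate_request() registers the request and hands back a queue.
        output_queue = await request_manager.generate_request(
            obj=tokenized_req, request_id=request_id, grpc_context=grpc_context
        )
        while True:
            try:
                # Poll with a short timeout so client cancellation is noticed promptly.
                output = await asyncio.wait_for(output_queue.get(), timeout=4)
            except asyncio.TimeoutError:
                if grpc_context is not None and grpc_context.cancelled():
                    await request_manager.abort_request(request_id)
                    return
                continue
            if "error" in output:
                raise RuntimeError(output["error"])
            yield output
            if output.get("finished", False):
                return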
python/sglang/srt/entrypoints/grpc_server.py  (view file @ 852a49c5)

@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)
 
-            # Submit to request manager (automatically handles n>1)
-            response_generator = self.request_manager.generate_request(
+            # Submit to request manager
+            output_queue = await self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )
 
-            async for output in response_generator:
-                # Handle batch responses (for n>1 non-streaming)
-                if isinstance(output, list):
-                    for batch_output in output:
-                        if "error" in batch_output:
-                            yield sglang_scheduler_pb2.GenerateResponse(
-                                request_id=request.request_id,
-                                error=sglang_scheduler_pb2.GenerateError(
-                                    message=batch_output["error"],
-                                    http_status_code=(
-                                        "500" if "abort" not in batch_output else "499"
-                                    ),
-                                ),
-                            )
-                        else:
-                            # All non-error batch outputs are final responses
-                            yield self._create_completion_response(
-                                request.request_id, batch_output
-                            )
-                else:
-                    # Handle single response (for streaming or n=1 non-streaming)
+            # Stream outputs
+            while True:
+                try:
+                    # Get output with timeout
+                    output = await asyncio.wait_for(output_queue.get(), timeout=4)
+
+                    # Check for errors
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,
@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                             ),
                         )
-                    elif output.get("finished", False):
+                        break
+
+                    # Check if finished
+                    if output.get("finished", False):
+                        # Send completion
                         yield self._create_completion_response(
                             request.request_id, output
                         )
+                        break
                     else:
+                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)
+
+                except asyncio.TimeoutError:
+                    # Check if context is still active
+                    if context.cancelled():
+                        # Abort the request
+                        await self.request_manager.abort_request(request.request_id)
+                        break
+                    continue
 
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(
@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                     prompt_tokens=result.get("prompt_tokens", 0),
                     cached_tokens=0,
                     embedding_dim=len(result["embedding"]),
+                    generation_time=time.time() - self.start_time,
                 ),
             )
@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         logger.info(f"Sending health check request to request manager...")
 
         # Submit and wait for response
-        output_generator = self.request_manager.generate_request(
+        output_queue = await self.request_manager.generate_request(
             health_request, request_id=rid
         )
 
         try:
-            # Get first response with timeout
+            # Wait for response with configurable timeout
             response = await asyncio.wait_for(
-                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
+                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
             )
 
             # Clean up
@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             return_logprob=grpc_req.return_logprob,
             logprob_start_len=grpc_req.logprob_start_len or -1,
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=grpc_req.stream or False,
-            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
+            stream=True,  # Always stream for gRPC
+            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
            ),
@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         regex = None
         json_schema = None
         ebnf_grammar = None
-        structural_tag = None
 
         if grpc_params.HasField("regex"):
             regex = grpc_params.regex
@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             json_schema = grpc_params.json_schema
         elif grpc_params.HasField("ebnf_grammar"):
             ebnf_grammar = grpc_params.ebnf_grammar
-        elif grpc_params.HasField("structural_tag"):
-            structural_tag = grpc_params.structural_tag
 
         return SGLSamplingParams(
             temperature=grpc_params.temperature or 1.0,
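Note: the constraint fields probed above (regex / json_schema / ebnf_grammar, plus structural_tag before this commit) appear to come from a protobuf oneof, which is why HasField is used to copy at most one of them. A generic sketch of that pattern; the message type and field list are assumptions based on the diff, not the actual sglang_scheduler.proto definition.

    def extract_constraint(grpc_params):
        """Return (kind, value) for whichever oneof constraint field is set, else (None, None).

        Assumes grpc_params is a protobuf message whose constraint fields form a oneof;
        the field names mirror the post-commit code above.
        """
        for field in ("regex", "json_schema", "ebnf_grammar"):
            if grpc_params.HasField(field):
                return field, getattr(grpc_params, field)
        return None, None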
@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             repetition_penalty=grpc_params.repetition_penalty or 1.0,
             max_new_tokens=grpc_params.max_new_tokens or 128,
             min_new_tokens=grpc_params.min_new_tokens or 0,
-            stop=list(grpc_params.stop) if grpc_params.stop else [],
+            stop=list(grpc_params.stop) if grpc_params.stop else None,
             stop_token_ids=(
-                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
+                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
             ),
             skip_special_tokens=grpc_params.skip_special_tokens,
             spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
             regex=regex,
             json_schema=json_schema,
             ebnf=ebnf_grammar,
-            structural_tag=structural_tag,
             n=grpc_params.n or 1,
             ignore_eos=grpc_params.ignore_eos,
         )
 
-    def _convert_logprobs_to_proto(
-        self, logprobs_data: Dict
-    ) -> Optional[sglang_scheduler_pb2.LogProbs]:
-        """Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
-        if not logprobs_data:
-            return None
-
-        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
-        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
-        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
-        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
-
-        # Build TopLogProbs entries
-        top_logprobs_proto = []
-        if top_logprobs_val and top_logprobs_idx:
-            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
-                top_logprobs_proto.append(
-                    sglang_scheduler_pb2.TopLogProbs(
-                        values=val_list,
-                        token_ids=idx_list,
-                    )
-                )
-
-        return sglang_scheduler_pb2.LogProbs(
-            token_logprobs=token_logprobs_val,
-            token_ids=token_logprobs_idx,
-            top_logprobs=top_logprobs_proto,
-        )
-
     def _create_chunk_response(
         self, request_id: str, output: Dict
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a streaming chunk response."""
-        meta_info = output.get("meta_info", {})
-
-        # Convert output logprobs if present
-        output_logprobs_proto = self._convert_logprobs_to_proto(
-            output.get("output_logprobs")
-        )
-
-        # Convert input logprobs if present (only in first chunk)
-        input_logprobs_proto = self._convert_logprobs_to_proto(
-            output.get("input_logprobs")
-        )
-
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_ids=output.get("token_ids", []),
-                prompt_tokens=meta_info.get("prompt_tokens", 0),
-                completion_tokens=meta_info.get("completion_tokens", 0),
-                cached_tokens=meta_info.get("cached_tokens", 0),
-                output_logprobs=output_logprobs_proto,
-                input_logprobs=input_logprobs_proto,
+                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+                text=output.get("text", ""),
+                prompt_tokens=0,
+                completion_tokens=len(output.get("token_ids", [])),
+                cached_tokens=0,
+                generation_time=time.time() - self.start_time,
+                queue_time=0.0,
             ),
         )
@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""
-        # Extract meta info and finish reason details
+        # Determine finish reason
+        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
         meta_info = output.get("meta_info", {})
-        finish_reason_data = meta_info.get("finish_reason")
-
-        # Determine finish reason, default is stop
-        finish_reason = "stop"
-        if finish_reason_data:
-            if isinstance(finish_reason_data, dict):
-                finish_reason_type = finish_reason_data.get("type")
-            else:
-                # Handle legacy string format
-                finish_reason_type = finish_reason_data
-            if finish_reason_type == "length":
-                finish_reason = "length"
-            elif finish_reason_type == "abort":
-                finish_reason = "abort"
-
-        # Extract matched_stop information
-        matched_stop_kwargs = {}
-        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
-            matched = finish_reason_data["matched"]
-            if isinstance(matched, int):
-                matched_stop_kwargs["matched_token_id"] = matched
-            elif isinstance(matched, str):
-                matched_stop_kwargs["matched_stop_str"] = matched
-
-        # Convert output logprobs if present
-        output_logprobs_proto = self._convert_logprobs_to_proto(
-            output.get("output_logprobs")
-        )
-
-        # Convert input logprobs if present
-        input_logprobs_proto = self._convert_logprobs_to_proto(
-            output.get("input_logprobs")
-        )
+        if meta_info.get("finish_reason") == "length":
+            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
+        elif meta_info.get("finish_reason") == "eos_token":
+            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
 
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             complete=sglang_scheduler_pb2.GenerateComplete(
                 output_ids=output.get("token_ids", []),
+                output_text=output.get("text", ""),
                 finish_reason=finish_reason,
-                prompt_tokens=meta_info.get("prompt_tokens", 0),
-                completion_tokens=meta_info.get(
-                    "completion_tokens", len(output.get("token_ids", []))
-                ),
-                cached_tokens=meta_info.get("cached_tokens", 0),
-                output_logprobs=output_logprobs_proto,
-                input_logprobs=input_logprobs_proto,
-                **matched_stop_kwargs,
             ),
         )
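Note: the replacement _create_completion_response maps the scheduler's finish_reason string onto the GenerateComplete enum instead of forwarding a structured dict. A small table-driven equivalent, assuming only the three enum members visible in the diff (STOP, LENGTH, EOS_TOKEN) exist:

    def map_finish_reason(meta_info: dict, complete_enum) -> int:
        # complete_enum is expected to be sglang_scheduler_pb2.GenerateComplete.
        mapping = {
            "length": complete_enum.LENGTH,
            "eos_token": complete_enum.EOS_TOKEN,
        }
        # Anything else (including None) falls back to STOP, matching the diff above.
        return mapping.get(meta_info.get("finish_reason"), complete_enum.STOP)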
python/sglang/srt/entrypoints/openai/protocol.py  (view file @ 852a49c5)

@@ -16,7 +16,7 @@
 import time
 import uuid
 from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
+from typing import Any, Dict, List, Optional, TypeAlias, Union
 
 from openai.types.responses import (
     ResponseFunctionToolCall,
@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[Union[List[str], str]] = None
-    # Cache salt for request caching
-    cache_salt: Optional[Union[List[str], str]] = None
     # Priority for the request
     priority: Optional[int] = None
-    # For custom metric labels
-    custom_labels: Optional[Dict[str, str]] = None
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
 
     @field_validator("max_tokens")
     @classmethod
@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
     """Function response."""
 
     name: Optional[str] = None
-    arguments: Optional[str | Dict[str, Any]] = None
+    arguments: Optional[str] = None
 
 
 class ToolCall(BaseModel):
@@ -392,7 +388,7 @@ class Function(BaseModel):
     """Function descriptions."""
 
     description: Optional[str] = Field(default=None, examples=[None])
-    name: str
+    name: Optional[str] = None
     parameters: Optional[object] = None
     strict: bool = False
@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[Union[List[str], str]] = None
-    # Cache salt for request caching
-    cache_salt: Optional[Union[List[str], str]] = None
     # Priority for the request
     priority: Optional[int] = None
@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
         description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
     )
     priority: int = Field(default=0, description="Request priority")
-    extra_key: Optional[str] = Field(
-        default=None,
-        description="Extra key for classifying the request (e.g. cache_salt)",
-    )
-    cache_salt: Optional[str] = Field(
-        default=None, description="Cache salt for request caching"
-    )
 
     # SGLang-specific sampling parameters
     frequency_penalty: float = 0.0
@@ -943,16 +928,6 @@ class MessageProcessingResult:
     tool_call_constraint: Optional[Any] = None
 
 
-class ToolCallProcessingResult(NamedTuple):
-    """Result of processing tool calls in a response."""
-
-    tool_calls: Optional[List[Any]]  # List of ToolCall objects or None if parsing failed
-    remaining_text: str  # Text remaining after parsing tool calls
-    finish_reason: Dict[str, Any]  # Updated finish reason dictionary
-
-
 class ResponseReasoningTextContent(BaseModel):
     text: str
     type: Literal["reasoning_text"] = "reasoning_text"
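Note: narrowing FunctionResponse.arguments from Optional[str | Dict[str, Any]] back to Optional[str] means callers must serialize structured arguments themselves. A hedged illustration with a stand-in pydantic model (not the real protocol class):

    import json
    from typing import Optional
    from pydantic import BaseModel, ValidationError

    class FunctionResponseNarrow(BaseModel):
        # Post-commit shape: arguments must already be a JSON string.
        name: Optional[str] = None
        arguments: Optional[str] = None

    args = {"location": "Berlin", "unit": "celsius"}

    try:
        FunctionResponseNarrow(name="get_weather", arguments=args)  # type: ignore[arg-type]
    except ValidationError:
        pass  # a dict is rejected once the union type is gone ...

    # ... so the caller serializes before constructing the response.
    ok = FunctionResponseNarrow(name="get_weather", arguments=json.dumps(args, ensure_ascii=False))
    print(ok.arguments)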
python/sglang/srt/entrypoints/openai/serving_base.py  (view file @ 852a49c5)

@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
         self.tokenizer_manager = tokenizer_manager
         self.allowed_custom_labels = (
             set(
-                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
+                self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
             )
             if isinstance(self.tokenizer_manager.server_args, ServerArgs)
-            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
+            and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
             else None
         )
@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
             return self.create_error_response(
                 message=e.detail, err_type=str(e.status_code), status_code=e.status_code
             )
-        except ValueError as e:
-            return self.create_error_response(
-                message=str(e),
-                err_type="BadRequest",
-                status_code=400,
-            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
         return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
 
-    def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
-        """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
-        parts = []
-        for key in ["cache_salt", "extra_key"]:
-            value = getattr(request, key, None)
-            if value:
-                if not isinstance(value, str):
-                    raise TypeError(
-                        f"Value of {key} must be a string, but got {type(value).__name__}"
-                    )
-                parts.append(value)
-        return "".join(parts) if parts else None
-
     @abstractmethod
     def _convert_to_internal_request(
         self,
@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
         )
         return json.dumps({"error": error.model_dump()})
 
-    def extract_custom_labels(self, raw_request):
+    def extract_customer_labels(self, raw_request):
         if (
             not self.allowed_custom_labels
             or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
         ):
             return None
 
-        custom_labels = None
+        customer_labels = None
         header = (
             self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
         )
@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
             raw_labels = None
 
         if isinstance(raw_labels, dict):
-            custom_labels = {
+            customer_labels = {
                 label: value
                 for label, value in raw_labels.items()
                 if label in self.allowed_custom_labels
             }
-        return custom_labels
+        return customer_labels
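Note: extract_custom(er)_labels reads a JSON object from a configured request header and keeps only whitelisted keys. A standalone sketch of that filtering step; the header value and allowed set here are made-up inputs, not SGLang defaults.

    import json
    from typing import Dict, Optional, Set

    def extract_labels_from_header(
        header_value: Optional[str], allowed: Set[str]
    ) -> Optional[Dict[str, str]]:
        if not header_value or not allowed:
            return None
        try:
            raw = json.loads(header_value)
        except json.JSONDecodeError:
            return None
        if not isinstance(raw, dict):
            return None
        # Keep only labels the server was configured to accept.
        labels = {k: v for k, v in raw.items() if k in allowed}
        return labels or None

    print(extract_labels_from_header('{"team": "search", "debug": "1"}', {"team"}))
    # {'team': 'search'}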
python/sglang/srt/entrypoints/openai/serving_chat.py  (view file @ 852a49c5)

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
-from jsonschema import Draft202012Validator, SchemaError
 
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
     LogProbs,
     MessageProcessingResult,
     ToolCall,
-    ToolCallProcessingResult,
-    ToolChoice,
     TopLogprob,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
     process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
-from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.function_call.json_array_parser import JsonArrayParser
-from sglang.srt.function_call.utils import get_json_schema_constraint
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
 from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
         super().__init__(tokenizer_manager)
         self.template_manager = template_manager
         self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-        self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
 
     def _request_id_prefix(self) -> str:
         return "chatcmpl-"
@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
         ):
             return "Tools cannot be empty if tool choice is set to required."
 
-        if request.tool_choice is not None and not isinstance(request.tool_choice, str):
-            if not request.tools:
-                return "Tools cannot be empty if tool choice is set to a specific tool."
-            tool_name = request.tool_choice.function.name
-            tool_exists = any(tool.function.name == tool_name for tool in request.tools)
-            if not tool_exists:
-                return f"Tool '{tool_name}' not found in tools list."
-
-        # Validate tool definitions
-        for i, tool in enumerate(request.tools or []):
-            if tool.function.parameters is None:
-                continue
-            try:
-                Draft202012Validator.check_schema(tool.function.parameters)
-            except SchemaError as e:
-                return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
-
         max_output_tokens = request.max_completion_tokens or request.max_tokens
         server_context_length = self.tokenizer_manager.server_args.context_length
         if (
...
@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
else
:
else
:
prompt_kwargs
=
{
"input_ids"
:
processed_messages
.
prompt_ids
}
prompt_kwargs
=
{
"input_ids"
:
processed_messages
.
prompt_ids
}
# Extract custom labels from raw request headers
# Extract custom
er
labels from raw request headers
custom_labels
=
self
.
extract_custom_labels
(
raw_request
)
custom
er
_labels
=
self
.
extract_custom
er
_labels
(
raw_request
)
adapted_request
=
GenerateReqInput
(
adapted_request
=
GenerateReqInput
(
**
prompt_kwargs
,
**
prompt_kwargs
,
...
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room
=
request
.
bootstrap_room
,
bootstrap_room
=
request
.
bootstrap_room
,
return_hidden_states
=
request
.
return_hidden_states
,
return_hidden_states
=
request
.
return_hidden_states
,
rid
=
request
.
rid
,
rid
=
request
.
rid
,
extra_key
=
self
.
_compute_extra_key
(
request
),
priority
=
request
.
priority
,
priority
=
request
.
priority
,
custom_labels
=
custom_labels
,
custom
er
_labels
=
custom
er
_labels
,
)
)
return
adapted_request
,
request
return
adapted_request
,
request
...
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_constraint
=
parser
.
get_structure_constraint
(
tool_call_constraint
=
parser
.
get_structure_constraint
(
request
.
tool_choice
request
.
tool_choice
)
)
# Handle JSON schema constraint directly for required or named tool choice
if
request
.
tool_choice
==
"required"
or
isinstance
(
request
.
tool_choice
,
ToolChoice
):
json_schema
=
get_json_schema_constraint
(
request
.
tools
,
request
.
tool_choice
)
tool_call_constraint
=
(
"json_schema"
,
json_schema
)
# Use chat template
# Use chat template
if
self
.
template_manager
.
chat_template_name
is
None
:
if
self
.
template_manager
.
chat_template_name
is
None
:
...
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
sampling_params
[
constraint_type
]
=
convert_json_schema_to_str
(
sampling_params
[
constraint_type
]
=
convert_json_schema_to_str
(
constraint_value
.
model_dump
(
by_alias
=
True
)
constraint_value
.
model_dump
(
by_alias
=
True
)
)
)
elif
constraint_type
==
"json_schema"
:
sampling_params
[
constraint_type
]
=
convert_json_schema_to_str
(
constraint_value
)
else
:
else
:
sampling_params
[
constraint_type
]
=
constraint_value
sampling_params
[
constraint_type
]
=
constraint_value
return
sampling_params
return
sampling_params
...
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers
[
index
]
=
stream_buffer
+
delta
stream_buffers
[
index
]
=
stream_buffer
+
delta
# Handle reasoning content
# Handle reasoning content
if
self
.
reasoning_parser
and
request
.
separate_reasoning
:
if
(
self
.
tokenizer_manager
.
server_args
.
reasoning_parser
and
request
.
separate_reasoning
):
reasoning_text
,
delta
=
self
.
_process_reasoning_stream
(
reasoning_text
,
delta
=
self
.
_process_reasoning_stream
(
index
,
delta
,
reasoning_parser_dict
,
content
,
request
index
,
delta
,
reasoning_parser_dict
,
content
,
request
)
)
...
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
# Handle reasoning content
reasoning_text
=
None
reasoning_text
=
None
reasoning_parser
=
self
.
reasoning_parser
reasoning_parser
=
self
.
tokenizer_manager
.
server_args
.
reasoning_parser
if
reasoning_parser
and
request
.
separate_reasoning
:
if
reasoning_parser
and
request
.
separate_reasoning
:
is_force_reasoning
=
(
is_force_reasoning
=
(
self
.
template_manager
.
force_reasoning
self
.
template_manager
.
force_reasoning
...
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
...
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
and
request
.
tools
and
request
.
tools
and
self
.
tool_call_parser
and
self
.
tool_call_parser
):
):
history_tool_calls_cnt
=
self
.
_get_history_tool_calls_cnt
(
request
)
tool_calls
,
text
,
finish_reason
=
self
.
_process_tool_calls
(
tool_calls
,
text
,
finish_reason
=
self
.
_process_tool_calls
(
text
,
text
,
request
.
tools
,
finish_reason
request
.
tools
,
finish_reason
,
request
.
tool_choice
,
history_tool_calls_cnt
,
)
)
choice_data
=
ChatCompletionResponseChoice
(
choice_data
=
ChatCompletionResponseChoice
(
...
@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
             token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
         return ChoiceLogprobs(content=token_logprobs)
 
-    def _process_tool_call_id(
-        self,
-        call_item: ToolCallItem,
-        history_tool_calls_cnt: int,
-    ) -> str:
-        """Process for generating a new and unique `tool_call_id`"""
-        if self.tool_call_parser != "kimi_k2":
-            # A simple uuid is sufficient for all models except for Kimi-K2.
-            tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
-            return tool_call_id
-        else:
-            # Align with Kimi-K2 format: functions.{name}:{index}
-            # Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
-            # Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
-            tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}"
-            logger.debug(
-                f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
-            )
-            return tool_call_id
-
     def _process_tool_calls(
         self,
         text: str,
         tools: List[Any],
         finish_reason: Dict[str, Any],
-        tool_choice: Optional[Union[str, ToolChoice]] = None,
-        history_tool_calls_cnt: int = 0,
-    ) -> ToolCallProcessingResult:
+    ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
         """Process tool calls in the response"""
-
-        # Handle required or named tool choice
-        if tool_choice == "required" or (
-            isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
-        ):
-            # Set finish reason to tool_calls since we're processing tool calls
-            if finish_reason["type"] == "stop":
-                finish_reason["type"] = "tool_calls"
-                finish_reason["matched"] = None
-            try:
-                # For required tool choice, we expect a JSON array of tool calls
-                tool_call_data = json.loads(text)
-                tool_calls = []
-                for i, tool in enumerate(tool_call_data):
-                    # Create a ToolCallItem from the JSON data
-                    call_info = ToolCallItem(
-                        tool_index=i,  # Use the loop index as tool_index
-                        name=tool["name"],
-                        parameters=json.dumps(tool["parameters"], ensure_ascii=False),
-                    )
-                    tool_id = self._process_tool_call_id(call_info, history_tool_calls_cnt)
-                    tool_calls.append(
-                        ToolCall(
-                            id=tool_id,
-                            index=i,
-                            function=FunctionResponse(
-                                name=tool["name"],
-                                arguments=json.dumps(tool["parameters"], ensure_ascii=False),
-                            ),
-                        )
-                    )
-                return ToolCallProcessingResult(tool_calls, "", finish_reason)
-            except json.JSONDecodeError as e:
-                logger.error(f"Tool call parsing error: {e}")
-                return ToolCallProcessingResult(None, text, finish_reason)
-
-        # Use parser since output is not constrained by JSON schema
         parser = FunctionCallParser(tools, self.tool_call_parser)
         if parser.has_tool_call(text):
             if finish_reason["type"] == "stop":
@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
                 text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = []
                 for call_info in call_info_list:
-                    tool_id = self._process_tool_call_id(
-                        call_info, history_tool_calls_cnt
-                    )
+                    # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
+                    if (
+                        self.tool_call_parser == "kimi_k2"
+                        and call_info.name is not None
+                    ):
+                        tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
+                    else:
+                        tool_id = f"call_{uuid.uuid4().hex[:24]}"
                     tool_calls.append(
                         ToolCall(
                             id=tool_id,
@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
                         ),
                     )
                 )
-                return ToolCallProcessingResult(tool_calls, text, finish_reason)
+                return tool_calls, text, finish_reason
             except Exception as e:
                 logger.error(f"Tool call parsing error: {e}")
                 # Return error but don't fail the whole request
-                return ToolCallProcessingResult(None, text, finish_reason)
+                return None, text, finish_reason
 
-        return ToolCallProcessingResult(None, text, finish_reason)
+        return None, text, finish_reason
 
     def _process_streaming_logprobs(
         self, content: Dict[str, Any], n_prev_token: int
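Note: with _process_tool_call_id removed, the Kimi-K2 id format is rebuilt inline in both the non-streaming path above and the streaming path below. A small helper equivalent to that inline logic; this is a sketch for illustration, not part of the commit.

    import uuid

    def make_tool_call_id(parser_name: str, call_name: str, tool_index: int) -> str:
        # Kimi-K2 expects ids shaped like "functions.{name}:{index}"; everything else
        # gets an opaque "call_" prefix plus a short uuid, as in the added code above.
        if parser_name == "kimi_k2" and call_name is not None:
            return f"functions.{call_name}:{tool_index}"
        return f"call_{uuid.uuid4().hex[:24]}"

    print(make_tool_call_id("kimi_k2", "get_weather", 0))   # functions.get_weather:0
    print(make_tool_call_id("llama3", "get_weather", 0))    # call_<24 hex chars>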
@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
                 or self._get_enable_thinking_from_request(request)
             )
             reasoning_parser_dict[index] = ReasoningParser(
-                self.reasoning_parser,
+                self.tokenizer_manager.server_args.reasoning_parser,
                 request.stream_reasoning,
                 is_force_reasoning,
             )
         reasoning_parser = reasoning_parser_dict[index]
         return reasoning_parser.parse_stream_chunk(delta)
 
-    def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
-        """Counts the number of tool calls in the request's message history.
-
-        NOTE: This method is only useful for models that include self-increasing
-        history tool call idx in tool calls id, such as kimi-k2
-
-        Args:
-            request: The chat completion request object.
-
-        Returns:
-            The total number of tool calls in the history, or 0 if not applicable.
-        """
-        messages = getattr(request, "messages", [])
-        idx = 0
-        for msg in messages:
-            if msg.role == "assistant":
-                tool_calls = getattr(msg, "tool_calls", None)
-                idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
-        return idx
-
     def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
         """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
         """
         if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
             # For Qwen3 models, `enable_thinking` is supported.
-            if self.reasoning_parser in ["qwen3", "glm45"]:
-                return request.chat_template_kwargs.get("enable_thinking", False)
+            if request.chat_template_kwargs.get("enable_thinking") is not None:
+                return request.chat_template_kwargs.get("enable_thinking")
             # For DeepSeek-V3.1 models, `thinking` is supported.
-            elif self.reasoning_parser in ["deepseek-v3"]:
-                return request.chat_template_kwargs.get("thinking", False)
+            elif request.chat_template_kwargs.get("thinking") is not None:
+                return request.chat_template_kwargs.get("thinking")
             else:
                 return False
         return False
@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
         ):
             """Process tool calls in streaming response"""
             if index not in parser_dict:
-                # Use JSON detector directly for required or named tool choice
-                if request.tool_choice == "required" or isinstance(
-                    request.tool_choice, ToolChoice
-                ):
-                    parser_dict[index] = JsonArrayParser()
-                else:
-                    parser_dict[index] = FunctionCallParser(
-                        tools=request.tools,
-                        tool_call_parser=self.tool_call_parser,
-                    )
+                parser_dict[index] = FunctionCallParser(
+                    tools=request.tools,
+                    tool_call_parser=self.tool_call_parser,
+                )
             parser = parser_dict[index]
 
-            # Handle both FunctionCallParser and JsonArrayParser
-            if isinstance(parser, JsonArrayParser):
-                result = parser.parse_streaming_increment(delta, request.tools)
-                normal_text, calls = result.normal_text, result.calls
-            else:
-                normal_text, calls = parser.parse_stream_chunk(delta)
+            normal_text, calls = parser.parse_stream_chunk(delta)
 
             # Yield normal text
             if normal_text:
@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
                 yield f"data: {chunk.model_dump_json()}\n\n"
 
             # Yield tool calls
-            history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
             for call_item in calls:
                 # Mark that this choice has tool calls
                 has_tool_calls[index] = True
@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
                 # Tool call ID should be generated only once per tool call
                 if call_item.name:
                     # First chunk: include ID and function name
-                    tool_call_id = self._process_tool_call_id(
-                        call_item, history_tool_calls_cnt
-                    )
+                    if self.tool_call_parser == "kimi_k2":
+                        # Align with Kimi-K2 format: functions.{name}:{index}
+                        tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
+                    else:
+                        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
                     function_name = call_item.name
                 else:
                     # Subsequent chunks: null ID and name for argument deltas
@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
     def _check_for_unstreamed_tool_args(
         self,
-        parser: Union[FunctionCallParser, JsonArrayParser],
+        parser: FunctionCallParser,
         content: Dict[str, Any],
         request: ChatCompletionRequest,
         index: int,
@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
         when generation finishes. This ensures tool calls are properly completed
         even if the model generates the final arguments in the last chunk.
         """
-        # Get the detector - either from FunctionCallParser or directly if json detector
-        detector = parser.detector if hasattr(parser, "detector") else parser
-
-        # Only check if we have tool calls and the detector has tracked data
+        # Only check if we have tool calls and the parser has tracked data
        if (
-            not hasattr(detector, "prev_tool_call_arr")
-            or not detector.prev_tool_call_arr
+            not hasattr(parser.detector, "prev_tool_call_arr")
+            or not parser.detector.prev_tool_call_arr
         ):
             return
None
return
None
if
(
if
(
not
hasattr
(
detector
,
"streamed_args_for_tool"
)
not
hasattr
(
parser
.
detector
,
"streamed_args_for_tool"
)
or
not
detector
.
streamed_args_for_tool
or
not
parser
.
detector
.
streamed_args_for_tool
):
):
return
None
return
None
# Get the last tool call that was being processed
# Get the last tool call that was being processed
tool_index
=
len
(
detector
.
prev_tool_call_arr
)
-
1
tool_index
=
len
(
parser
.
detector
.
prev_tool_call_arr
)
-
1
if
tool_index
<
0
or
tool_index
>=
len
(
detector
.
streamed_args_for_tool
):
if
tool_index
<
0
or
tool_index
>=
len
(
parser
.
detector
.
streamed_args_for_tool
):
return
None
return
None
# Get expected vs actual arguments
# Get expected vs actual arguments
expected_args
=
detector
.
prev_tool_call_arr
[
tool_index
].
get
(
"arguments"
,
{})
expected_args
=
parser
.
detector
.
prev_tool_call_arr
[
tool_index
].
get
(
"arguments"
,
{}
)
expected_call
=
json
.
dumps
(
expected_args
,
ensure_ascii
=
False
)
expected_call
=
json
.
dumps
(
expected_args
,
ensure_ascii
=
False
)
actual_call
=
detector
.
streamed_args_for_tool
[
tool_index
]
actual_call
=
parser
.
detector
.
streamed_args_for_tool
[
tool_index
]
# Check if there are remaining arguments to send
# Check if there are remaining arguments to send
remaining_call
=
(
remaining_call
=
(
...
...
python/sglang/srt/entrypoints/openai/serving_completions.py
View file @
852a49c5
...
@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
...
@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
else
:
else
:
prompt_kwargs
=
{
"input_ids"
:
prompt
}
prompt_kwargs
=
{
"input_ids"
:
prompt
}
# Extract custom labels from raw request headers
# Extract custom
er
labels from raw request headers
custom_labels
=
self
.
extract_custom_labels
(
raw_request
)
custom
er
_labels
=
self
.
extract_custom
er
_labels
(
raw_request
)
adapted_request
=
GenerateReqInput
(
adapted_request
=
GenerateReqInput
(
**
prompt_kwargs
,
**
prompt_kwargs
,
...
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
...
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room
=
request
.
bootstrap_room
,
bootstrap_room
=
request
.
bootstrap_room
,
return_hidden_states
=
request
.
return_hidden_states
,
return_hidden_states
=
request
.
return_hidden_states
,
rid
=
request
.
rid
,
rid
=
request
.
rid
,
extra_key
=
self
.
_compute_extra_key
(
request
),
priority
=
request
.
priority
,
priority
=
request
.
priority
,
custom_labels
=
custom_labels
,
custom
er
_labels
=
custom
er
_labels
,
)
)
return
adapted_request
,
request
return
adapted_request
,
request
...
...
python/sglang/srt/entrypoints/openai/serving_responses.py
View file @ 852a49c5
...
@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
             sampling_params=sampling_params,
             stream=request.stream,
             rid=request.request_id,
-            extra_key=self._compute_extra_key(request),
             background=request.background,
         )
...
@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
             sampling_params=sampling_params,
             stream=adapted_request.stream,
             rid=request_id,
-            extra_key=adapted_request.extra_key,
             return_logprob=adapted_request.return_logprob,
             logprob_start_len=adapted_request.logprob_start_len,
             top_logprobs_num=adapted_request.top_logprobs_num,
...
python/sglang/srt/eplb/expert_location.py
View file @ 852a49c5
...
@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=(
                 compute_logical_to_rank_dispatch_physical_map(
-                    server_args=server_args,
                     logical_to_all_physical_map=logical_to_all_physical_map,
                     num_gpus=ep_size,
                     num_physical_experts=num_physical_experts,
...
@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
 # TODO optimize performance (rewrite and/or run in separate process with overlap)
 def compute_logical_to_rank_dispatch_physical_map(
-    server_args: ServerArgs,
     logical_to_all_physical_map: torch.Tensor,
     num_gpus: int,
     num_physical_experts: int,
...
@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
 ):
     r = random.Random(seed)
 
-    num_local_gpu_physical_experts = num_physical_experts // num_gpus
-    num_gpus_per_node = server_args.ep_size // server_args.nnodes
-    num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
+    num_local_physical_experts = num_physical_experts // num_gpus
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
     dtype = logical_to_all_physical_map.dtype
...
@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
                 physical_expert_id
                 for physical_expert_id in candidate_physical_expert_ids
                 if _compute_gpu_id_of_physical_expert(
-                    physical_expert_id, num_local_gpu_physical_experts
+                    physical_expert_id, num_local_physical_experts
                 )
                 == gpu_id
             ]
             if len(same_gpu_physical_expert_ids) > 0:
-                # 1. Prefer same-GPU experts
                 output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
-            else:
-                # 2. Otherwise, prefer same-node experts
-                node_id = gpu_id // num_gpus_per_node
-                same_node_physical_expert_ids = [
-                    physical_expert_id
-                    for physical_expert_id in candidate_physical_expert_ids
-                    if _compute_node_id_of_physical_expert(
-                        physical_expert_id, num_local_node_physical_experts
-                    )
-                    == node_id
-                ]
-                if len(same_node_physical_expert_ids) > 0:
-                    output_partial[gpu_id] = same_node_physical_expert_ids[0]
 
-        # 3. Fill remaining slots with fair random choices
         num_remain = torch.sum(output_partial == -1).item()
         output_partial[output_partial == -1] = torch.tensor(
             _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
...
@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
 def _compute_gpu_id_of_physical_expert(
-    physical_expert_id: int, num_local_gpu_physical_experts: int
+    physical_expert_id: int, num_local_physical_experts: int
 ) -> int:
-    return physical_expert_id // num_local_gpu_physical_experts
-
-
-def _compute_node_id_of_physical_expert(
-    physical_expert_id: int, num_local_host_physical_experts: int
-) -> int:
-    return physical_expert_id // num_local_host_physical_experts
+    return physical_expert_id // num_local_physical_experts
 
 
 def _fair_choices(arr: List, k: int, r: random.Random) -> List:
...
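Note (illustrative, not part of this commit): after this change the dispatch map uses only two steps per GPU: prefer a physical replica that lives on the querying GPU (its id divided by the per-GPU expert count), otherwise fill the slot by fair random choice. A small sketch under the assumption of a contiguous per-GPU expert layout:

import random

def dispatch_for_gpu(candidates, gpu_id, num_local_physical_experts, r):
    # A physical expert's GPU id is physical_expert_id // num_local_physical_experts.
    same_gpu = [p for p in candidates if p // num_local_physical_experts == gpu_id]
    if same_gpu:
        return same_gpu[0]
    # No replica on this GPU: fall back to a random candidate ("fair" fill).
    return r.choice(candidates)

r = random.Random(42)
# A logical expert replicated on physical experts 3 and 12, with 4 experts per GPU.
print(dispatch_for_gpu([3, 12], gpu_id=0, num_local_physical_experts=4, r=r))  # 3
print(dispatch_for_gpu([3, 12], gpu_id=1, num_local_physical_experts=4, r=r))  # 3 or 12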
python/sglang/srt/function_call/function_call_parser.py
View file @ 852a49c5
...
@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
 from sglang.srt.function_call.step3_detector import Step3Detector
-from sglang.srt.function_call.utils import get_json_schema_constraint
 
 logger = logging.getLogger(__name__)
...
@@ -179,8 +178,8 @@ class FunctionCallParser:
             strict_tag = self.get_structure_tag()
             return ("structural_tag", strict_tag)
         elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
-            json_schema = get_json_schema_constraint(self.tools, tool_choice)
-            return ("json_schema", json_schema)
+            ebnf = self.get_ebnf(tool_choice)
+            return ("ebnf", ebnf) if ebnf is not None else None
 
     def get_ebnf(
         self, tool_choice: Union[ToolChoice, Literal["required"]]
...
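Note (illustrative, not part of this commit): with this hunk, tool_choice="required" or a named tool now yields an ("ebnf", grammar) constraint pair instead of ("json_schema", schema), and None when no grammar can be built. A toy dispatcher (ToolChoice replaced by a plain dict for the sketch) showing the shape of the return value:

from typing import Optional, Tuple, Union

def pick_constraint(
    tool_choice: Union[str, dict], ebnf: Optional[str]
) -> Optional[Tuple[str, str]]:
    # Mirrors the new branch above: emit an EBNF constraint when one exists,
    # otherwise no constraint at all.
    if tool_choice == "required" or isinstance(tool_choice, dict):
        return ("ebnf", ebnf) if ebnf is not None else None
    return None

print(pick_constraint("required", 'root ::= "[" call ("," call)* "]"'))
print(pick_constraint("required", None))  # None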
python/sglang/srt/function_call/glm4_moe_detector.py
View file @ 852a49c5
...
@@ -39,7 +39,7 @@ def parse_arguments(json_value):
 class Glm4MoeDetector(BaseFormatDetector):
     """
-    Detector for GLM-4.5 and GLM-4.6 models.
+    Detector for GLM-4.5 models.
     Assumes function call format:
         <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
     """
...
@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
         self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
 
     def has_tool_call(self, text: str) -> bool:
-        """Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
+        """Check if the text contains a glm-4.5 format tool call."""
         return self.bot_token in text
 
     def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
...
@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
         self, new_text: str, tools: List[Tool]
     ) -> StreamingParseResult:
         """
-        Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
+        Streaming incremental parsing tool calls for GLM-4.5 format.
         """
         self._buffer += new_text
         current_text = self._buffer
...
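Note (illustrative, not part of this commit): the detector's func_arg_regex shown above pairs each <arg_key> with its <arg_value>. A self-contained sketch that applies the same pattern to a made-up tool call:

import re

FUNC_ARG_REGEX = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"

text = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n"
    "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
    "</tool_call>"
)

func_name = text.split("<tool_call>", 1)[1].split("\n", 1)[0]
args = dict(re.findall(FUNC_ARG_REGEX, text, flags=re.DOTALL))
print(func_name, args)  # get_weather {'city': '北京', 'date': '2024-06-27'}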
python/sglang/srt/function_call/json_array_parser.py
deleted 100644 → 0
View file @ 8f7453e3
-import json
-import re
-from typing import List
-
-from sglang.srt.entrypoints.openai.protocol import Tool
-from sglang.srt.function_call.base_format_detector import BaseFormatDetector
-from sglang.srt.function_call.core_types import StreamingParseResult
-
-
-class JsonArrayParser(BaseFormatDetector):
-    """
-    Parser for JSON array tool calls when JSON schema constraints are active.
-
-    This parser is used when tool_choice="required" or a specific tool is named,
-    bypassing model-specific parsers in favor of direct JSON array parsing.
-    """
-
-    def __init__(self):
-        super().__init__()
-        # Configure for JSON array parsing
-        self.bot_token = "["
-        self.eot_token = "]"
-        self.tool_call_separator = ","
-
-    def has_tool_call(self, text: str) -> bool:
-        """
-        Check if the given text contains a JSON tool call (array or single object).
-        """
-        return "[" in text or "{" in text
-
-    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
-        """
-        Parse JSON tool calls using the base class implementation.
-        """
-        raise NotImplementedError(
-            "Detect and parse not supported for JSON schema constraints."
-        )
-
-    def build_ebnf(self, tools: List[Tool]) -> str:
-        """
-        Build an EBNF grammar for constrained generation.
-        This is not used for JSON schema constraints as they are handled
-        by the constraint backends directly.
-        """
-        raise NotImplementedError(
-            "EBNF generation is not supported for JSON schema constraints."
-        )
-
-    def parse_streaming_increment(
-        self, new_text: str, tools: List[Tool]
-    ) -> StreamingParseResult:
-        """
-        Streaming incremental parsing with tool validation.
-        """
-        return super().parse_streaming_increment(new_text, tools)
-
-    def structure_info(self) -> callable:
-        """
-        Return a function that creates StructureInfo for constrained generation.
-        This is not used for JSON schema constraints as they are handled
-        by the constraint backends directly.
-        """
-        raise NotImplementedError("structure_info not used for JSON schema constraints")
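Note (illustrative, not part of this commit): the deleted parser consumed tool calls emitted as one JSON array of {name, parameters} objects, bounded by its bot_token "[" and eot_token "]". Once such an array is complete it can be read with plain json; the streaming/incremental handling lived in BaseFormatDetector and is not reproduced here:

import json

completed = '[{"name": "get_weather", "parameters": {"city": "上海"}}]'
for call in json.loads(completed):
    print(call["name"], call["parameters"])  # get_weather {'city': '上海'}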
python/sglang/srt/function_call/utils.py
View file @ 852a49c5
 import json
 from json import JSONDecodeError, JSONDecoder
-from json.decoder import WHITESPACE
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Tuple
 
 import partial_json_parser
 from partial_json_parser.core.options import Allow
 
-from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
-
 
 def _find_common_prefix(s1: str, s2: str) -> str:
     prefix = ""
...
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
     """
     try:
         return (partial_json_parser.loads(input_str, flags), len(input_str))
-    except (JSONDecodeError, IndexError) as e:
-        msg = getattr(e, "msg", str(e))
-        if "Extra data" in msg or "pop from empty list" in msg:
-            start = WHITESPACE.match(input_str, 0).end()
-            obj, end = JSONDecoder().raw_decode(input_str, start)
-            return obj, end
+    except JSONDecodeError as e:
+        if "Extra data" in e.msg:
+            dec = JSONDecoder()
+            return dec.raw_decode(input_str)
         raise
...
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
         return True
     except JSONDecodeError:
         return False
-
-
-def _get_tool_schema_defs(tools: List[Tool]) -> dict:
-    """
-    Get consolidated $defs from all tools, validating for conflicts.
-
-    Args:
-        tools: List of tools to process
-
-    Returns:
-        Dictionary of consolidated $defs from all tools
-
-    Raises:
-        ValueError: If conflicting $defs are found
-    """
-    all_defs = {}
-    for tool in tools:
-        if tool.function.parameters is None:
-            continue
-        defs = tool.function.parameters.get("$defs", {})
-        for def_name, def_schema in defs.items():
-            if def_name in all_defs and all_defs[def_name] != def_schema:
-                raise ValueError(
-                    f"Tool definition '{def_name}' has "
-                    "multiple schemas, which is not "
-                    "supported."
-                )
-            else:
-                all_defs[def_name] = def_schema
-    return all_defs
-
-
-def _get_tool_schema(tool: Tool) -> dict:
-    return {
-        "properties": {
-            "name": {"type": "string", "enum": [tool.function.name]},
-            "parameters": (
-                tool.function.parameters
-                if tool.function.parameters
-                else {"type": "object", "properties": {}}
-            ),
-        },
-        "required": ["name", "parameters"],
-    }
-
-
-def get_json_schema_constraint(
-    tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
-) -> Optional[dict]:
-    """
-    Get the JSON schema constraint for the specified tool choice.
-
-    Args:
-        tool_choice: The tool choice specification
-
-    Returns:
-        JSON schema dict, or None if no valid tools found
-    """
-    if isinstance(tool_choice, ToolChoice):
-        # For specific function choice, return the user's parameters schema directly
-        fn_name = tool_choice.function.name
-        for tool in tools:
-            if tool.function.name == fn_name:
-                return {
-                    "type": "array",
-                    "minItems": 1,
-                    "maxItems": 1,
-                    "items": _get_tool_schema(tool),
-                }
-        return None
-    elif tool_choice == "required":
-        json_schema = {
-            "type": "array",
-            "minItems": 1,
-            "items": {
-                "type": "object",
-                "anyOf": [_get_tool_schema(tool) for tool in tools],
-            },
-        }
-        json_schema_defs = _get_tool_schema_defs(tools)
-        if json_schema_defs:
-            json_schema["$defs"] = json_schema_defs
-        return json_schema
-    return None
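Note (illustrative, not part of this commit): for reference, the removed get_json_schema_constraint wrapped each tool's name and parameters into an array schema. A hedged sketch that rebuilds that shape for one made-up tool, without the sglang Tool/ToolChoice classes:

import json

def tool_schema(name: str, parameters: dict) -> dict:
    # Same shape as the removed _get_tool_schema helper.
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": parameters or {"type": "object", "properties": {}},
        },
        "required": ["name", "parameters"],
    }

weather = {"type": "object", "properties": {"city": {"type": "string"}}}
required_schema = {
    "type": "array",
    "minItems": 1,
    "items": {"type": "object", "anyOf": [tool_schema("get_weather", weather)]},
}
print(json.dumps(required_schema, indent=2))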
python/sglang/srt/grpc/sglang_scheduler.proto
View file @ 852a49c5
...
@@ -36,9 +36,9 @@ message SamplingParams {
   float presence_penalty = 6;
   float repetition_penalty = 7;
-  optional int32 max_new_tokens = 8;
+  int32 max_new_tokens = 8;
   repeated string stop = 9;
-  repeated uint32 stop_token_ids = 10;
+  repeated int32 stop_token_ids = 10;
   bool skip_special_tokens = 11;
   bool spaces_between_special_tokens = 12;
...
@@ -47,24 +47,24 @@ message SamplingParams {
     string regex = 13;
     string json_schema = 14;
     string ebnf_grammar = 15;
-    string structural_tag = 16;
   }
 
   // LoRA adapter
-  string lora_path = 17;
+  string lora_path = 16;
 
   // Speculative decoding
-  int32 n = 18;  // Number of samples
+  int32 n = 17;  // Number of samples
 
   // Token healing
-  bool token_healing = 19;
+  bool token_healing = 18;
 
   // Additional parameters
-  int32 min_new_tokens = 20;
-  bool ignore_eos = 21;
-  bool no_stop_trim = 22;
-  int32 stream_interval = 23;
-  map<string, float> logit_bias = 24;
+  int32 min_new_tokens = 19;
+  bool ignore_eos = 20;
+  bool no_stop_trim = 21;
+  int32 stream_interval = 22;
+  map<string, float> logit_bias = 23;
+  string structural_tag = 24;
 
   // Custom parameters for extensibility
   google.protobuf.Struct custom_params = 25;
...
@@ -98,7 +98,7 @@ message GenerateRequest {
   bool return_logprob = 5;
   int32 logprob_start_len = 6;
   int32 top_logprobs_num = 7;
-  repeated uint32 token_ids_logprob = 8;
+  repeated int32 token_ids_logprob = 8;
   bool return_hidden_states = 9;
 
   // For disaggregated serving
...
@@ -122,14 +122,11 @@ message GenerateRequest {
   // For load balancing
   int32 dp_balance_id = 17;
-
-  // Whether client wants streaming response
-  bool stream = 18;
 }
 
 message TokenizedInput {
   string original_text = 1;  // For reference
-  repeated uint32 input_ids = 2;
+  repeated int32 input_ids = 2;
 }
 
 message MultimodalInputs {
...
@@ -166,50 +163,51 @@ message GenerateResponse {
 }
 
 message GenerateStreamChunk {
-  // Generated tokens (incremental chunk)
-  repeated uint32 token_ids = 1;
+  // Generated token
+  int32 token_id = 1;
+  string text = 2;
 
   // Cumulative counts
-  int32 prompt_tokens = 2;
-  int32 completion_tokens = 3;
-  int32 cached_tokens = 4;
+  int32 prompt_tokens = 3;
+  int32 completion_tokens = 4;
+  int32 cached_tokens = 5;
 
-  // Output logprobs (if requested) - incremental for streaming
-  LogProbs output_logprobs = 5;
+  // Logprobs (if requested)
+  LogProbs logprobs = 6;
 
   // Hidden states (if requested)
-  repeated float hidden_states = 6;
+  repeated float hidden_states = 7;
 
-  // Input logprobs (if requested) - only in first chunk
-  LogProbs input_logprobs = 7;
+  // Metadata
+  float generation_time = 8;  // Time to generate this token
+  int32 queue_time = 9;  // Time spent in queue
 }
 
 message GenerateComplete {
   // Final output
-  repeated uint32 output_ids = 1;
+  repeated int32 output_ids = 1;
+  string output_text = 2;
 
-  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
-  string finish_reason = 2;
+  // Finish reason
+  enum FinishReason {
+    // The model generated a stop sequence.
+    STOP = 0;
+    // The model reached the maximum generation length.
+    LENGTH = 1;
+    // The model generated an end-of-sequence (EOS) token.
+    EOS_TOKEN = 2;
+    // The model generated a user-provided stop string.
+    STOP_STR = 3;
+    // The request was aborted by the user or system.
+    ABORT = 4;
+  }
+  FinishReason finish_reason = 3;
 
-  // Token usage counts
-  int32 prompt_tokens = 3;
-  int32 completion_tokens = 4;
-  int32 cached_tokens = 5;
-
-  // Output logprobs if requested (cumulative)
-  LogProbs output_logprobs = 6;
+  // All logprobs if requested
+  repeated LogProbs all_logprobs = 11;
 
   // All hidden states if requested
-  repeated HiddenStates all_hidden_states = 7;
-
-  // Matched stop information (for stop sequences)
-  oneof matched_stop {
-    uint32 matched_token_id = 8;
-    string matched_stop_str = 9;
-  }
-
-  // Input logprobs if requested (for prompt tokens)
-  LogProbs input_logprobs = 10;
+  repeated HiddenStates all_hidden_states = 12;
 }
 
 message GenerateError {
...
@@ -224,11 +222,15 @@ message LogProbs {
   // Top logprobs at each position
   repeated TopLogProbs top_logprobs = 3;
+
+  // Decoded text for tokens
+  repeated string token_texts = 4;
 }
 
 message TopLogProbs {
   repeated float values = 1;
   repeated int32 token_ids = 2;
+  repeated string token_texts = 3;
 }
 
 message HiddenStates {
...
@@ -283,9 +285,10 @@ message EmbedComplete {
   // Additional metadata
   int32 embedding_dim = 4;
+  float generation_time = 5;
 
   // For batch embeddings
-  repeated Embedding batch_embeddings = 5;
+  repeated Embedding batch_embeddings = 6;
 }
 
 message Embedding {
...
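Note (illustrative, not part of this commit): with the revised messages above, a stream chunk carries one token_id plus its decoded text and timing, and completion reports a FinishReason enum. A sketch of building both messages, assuming the Python bindings have been regenerated from this .proto (e.g. via compile_proto.py) as sglang_scheduler_pb2:

from sglang.srt.grpc import sglang_scheduler_pb2 as pb

chunk = pb.GenerateStreamChunk(
    token_id=1234,
    text=" hello",
    prompt_tokens=12,
    completion_tokens=3,
    cached_tokens=8,
    generation_time=0.015,  # time to generate this token
    queue_time=2,           # time spent in queue
)

complete = pb.GenerateComplete(
    output_ids=[1234, 5678],
    output_text=" hello world",
    finish_reason=pb.GenerateComplete.EOS_TOKEN,
)
print(chunk.token_id, pb.GenerateComplete.FinishReason.Name(complete.finish_reason))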
python/sglang/srt/grpc/sglang_scheduler_pb2.py
View file @ 852a49c5
This diff is collapsed.
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
View file @ 852a49c5
...
@@ -3,6 +3,7 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
+from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
...
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
 DESCRIPTOR: _descriptor.FileDescriptor
 
 class SamplingParams(_message.Message):
-    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
+    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
     class LogitBiasEntry(_message.Message):
         __slots__ = ("key", "value")
         KEY_FIELD_NUMBER: _ClassVar[int]
...
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
     REGEX_FIELD_NUMBER: _ClassVar[int]
     JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
     EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
-    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     LORA_PATH_FIELD_NUMBER: _ClassVar[int]
     N_FIELD_NUMBER: _ClassVar[int]
     TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
...
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
     NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
     STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
     LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
+    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
     temperature: float
     top_p: float
...
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
     regex: str
     json_schema: str
     ebnf_grammar: str
-    structural_tag: str
     lora_path: str
     n: int
     token_healing: bool
...
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
     no_stop_trim: bool
     stream_interval: int
     logit_bias: _containers.ScalarMap[str, float]
+    structural_tag: str
     custom_params: _struct_pb2.Struct
-    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
+    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
 
 class DisaggregatedParams(_message.Message):
     __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
...
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
 
 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
...
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
     DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
-    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs
...
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
     lora_id: str
     data_parallel_rank: int
     dp_balance_id: int
-    stream: bool
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
 
 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")
...
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
 
 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
-    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
+    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    TEXT_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
-    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
-    token_ids: _containers.RepeatedScalarFieldContainer[int]
+    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
+    QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
+    token_id: int
+    text: str
    prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
-    output_logprobs: LogProbs
+    logprobs: LogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-    input_logprobs: LogProbs
-    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
+    generation_time: float
+    queue_time: int
+    def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
 
 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
+    __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
+    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
+        __slots__ = ()
+        STOP: _ClassVar[GenerateComplete.FinishReason]
+        LENGTH: _ClassVar[GenerateComplete.FinishReason]
+        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
+        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
+        ABORT: _ClassVar[GenerateComplete.FinishReason]
+    STOP: GenerateComplete.FinishReason
+    LENGTH: GenerateComplete.FinishReason
+    EOS_TOKEN: GenerateComplete.FinishReason
+    STOP_STR: GenerateComplete.FinishReason
+    ABORT: GenerateComplete.FinishReason
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
+    OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
-    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
-    COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
-    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
-    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
-    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
-    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-    finish_reason: str
-    prompt_tokens: int
-    completion_tokens: int
-    cached_tokens: int
-    output_logprobs: LogProbs
+    output_text: str
+    finish_reason: GenerateComplete.FinishReason
+    all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-    matched_token_id: int
-    matched_stop_str: str
-    input_logprobs: LogProbs
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
 
 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
...
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
     def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
 
 class LogProbs(_message.Message):
-    __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
+    __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
     TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     token_logprobs: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
     top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
-    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+    token_texts: _containers.RepeatedScalarFieldContainer[str]
+    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
 
 class TopLogProbs(_message.Message):
-    __slots__ = ("values", "token_ids")
+    __slots__ = ("values", "token_ids", "token_texts")
     VALUES_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     values: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
-    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
+    token_texts: _containers.RepeatedScalarFieldContainer[str]
+    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
 
 class HiddenStates(_message.Message):
     __slots__ = ("values", "layer", "position")
...
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
 
 class EmbedComplete(_message.Message):
-    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
+    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
     EMBEDDING_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
+    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
     BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
     embedding: _containers.RepeatedScalarFieldContainer[float]
     prompt_tokens: int
     cached_tokens: int
     embedding_dim: int
+    generation_time: float
     batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
-    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
+    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
 
 class Embedding(_message.Message):
     __slots__ = ("values", "index")
...
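Note (illustrative, not part of this commit): the stub above adds token_texts to LogProbs and TopLogProbs so each logprob entry can carry its decoded token. A short sketch using the regenerated bindings (same assumption as the earlier example):

from sglang.srt.grpc import sglang_scheduler_pb2 as pb

lp = pb.LogProbs(
    token_logprobs=[-0.11, -1.73],
    token_ids=[1234, 5678],
    token_texts=[" hello", " world"],  # new field surfaced by this stub
)
print(list(zip(lp.token_ids, lp.token_texts)))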
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
View file @ 852a49c5
-# This file is auto-generated. Do not edit manually.
-# Regenerate with: python compile_proto.py
 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
 """Client and server classes corresponding to protobuf-defined services."""
 import grpc
...
python/sglang/srt/hf_transformers_utils.py
View file @ 852a49c5
...
@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+def _load_deepseek_v32_model(
+    model_path: str,
+    trust_remote_code: bool = False,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    # first get the local path
+    local_path = download_from_hf(model_path)
+    # then load the config file in json
+    config_file = os.path.join(local_path, "config.json")
+    if not os.path.exists(config_file):
+        raise RuntimeError(f"Can't find config file in {local_path}.")
+    with open(config_file, "r") as f:
+        config_json = json.load(f)
+
+    config_json["architectures"] = ["DeepseekV3ForCausalLM"]
+    config_json["model_type"] = "deepseek_v3"
+
+    tmp_path = os.path.join(local_path, "_tmp_config_folder")
+    os.makedirs(tmp_path, exist_ok=True)
+
+    unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
+    with open(unique_path, "w") as f:
+        json.dump(config_json, f)
+
+    return AutoConfig.from_pretrained(
+        unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+    )
+
+
 @lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
...
@@ -140,9 +171,17 @@ def get_config(
         client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
         model = client.get_local_dir()
 
-    config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
-    )
+    try:
+        config = AutoConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except ValueError as e:
+        if not "deepseek_v32" in str(e):
+            raise e
+        config = _load_deepseek_v32_model(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
 
     if (
         config.architectures is not None
         and config.architectures[0] == "Phi4MMForCausalLM"
...
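Note (illustrative, not part of this commit): the fallback above works by rewriting config.json so a deepseek_v32 checkpoint is loaded through the existing DeepSeek-V3 config class. A standalone sketch of just the patching step (the scratch directory is made up; the real helper writes next to the downloaded model):

import json
import os
import tempfile

def patch_dsv32_config(local_path: str) -> str:
    # Rewrite architectures/model_type the same way _load_deepseek_v32_model does,
    # then return a directory whose config.json AutoConfig can parse.
    with open(os.path.join(local_path, "config.json")) as f:
        cfg = json.load(f)
    cfg["architectures"] = ["DeepseekV3ForCausalLM"]
    cfg["model_type"] = "deepseek_v3"
    out_dir = tempfile.mkdtemp(prefix="deepseek_v32_")
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(cfg, f)
    return out_dir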
python/sglang/srt/layers/attention/aiter_backend.py
View file @ 852a49c5
...
@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
             assert len(k.shape) == 3
             assert len(v.shape) == 3
 
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
+            if forward_batch.forward_mode.is_extend():
                 if kv_indices.shape[0] == 0:
                     o = flash_attn_varlen_func(
                         q,
...
python/sglang/srt/layers/attention/ascend_backend.py
View file @
852a49c5
...
@@ -3,6 +3,7 @@ from __future__ import annotations
...
@@ -3,6 +3,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
+import custom_ops
 import torch
 import torch_npu
 from torch.nn.functional import scaled_dot_product_attention
...
@@ -36,6 +37,8 @@ class ForwardMetadata:
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
     seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
...
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
             self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
...
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
         seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-        if forward_batch.is_extend_in_batch:
-            seq_lens_list_cumsum[-1] = (
-                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
-            ) * tp_size
         self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
...
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
...
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
             metadata.block_tables[:bs, max_seq_pages:].fill_(0)
             metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
...
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+        return attn_out
+
     def forward_extend(
         self,
         q,
...
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
...
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
         if is_mla_preprocess_enabled():
            # MLAPO does saving kv_cache
            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if self.graph_mode:
             return self.forward_decode_graph(
...
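Note (reviewer sketch, not part of the commit): the `topk_indices` parameter threaded through `forward_extend` and the decode path is what selects the new sparse route — when the caller supplies per-query top-k KV indices, the backend dispatches to `forward_sparse` and the NPU sparse flash-attention kernel; otherwise the existing dense paths run unchanged. Below is a minimal, self-contained illustration of that dispatch shape only; `attend`, `dense_attention`, and `sparse_attention` are placeholders, not the real kernels.

from typing import Optional

import torch


def attend(
    q: torch.Tensor, kv: torch.Tensor, topk_indices: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Toy stand-in for the backend entry point: dense unless sparse indices are given."""

    def dense_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        # Placeholder for the existing dense attention path.
        return torch.softmax(q @ kv.T, dim=-1) @ kv

    def sparse_attention(
        q: torch.Tensor, kv: torch.Tensor, idx: torch.Tensor
    ) -> torch.Tensor:
        # Placeholder for forward_sparse(): only the selected KV rows participate.
        kv_sel = kv[idx]
        return torch.softmax(q @ kv_sel.T, dim=-1) @ kv_sel

    if topk_indices is not None:
        return sparse_attention(q, kv, topk_indices)
    return dense_attention(q, kv)


# Example: 4 query tokens, 16 cached KV rows, keep only 8 of them for the sparse path.
q = torch.randn(4, 64)
kv = torch.randn(16, 64)
out_dense = attend(q, kv)
out_sparse = attend(q, kv, topk_indices=torch.arange(8))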
python/sglang/srt/layers/attention/attention_registry.py
View file @
852a49c5
-import logging
-logger = logging.getLogger(__name__)
 ATTENTION_BACKENDS = {}
...
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
     return AscendAttnBackend(runner)
 
 
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
 @register_attention_backend("triton")
 def create_triton_backend(runner):
     assert not runner.model_config.is_encoder_decoder, (
...
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
     return DualChunkFlashAttentionBackend(runner)
 
 
-def attn_backend_wrapper(runner, full_attn_backend):
-    """
-    Wrapper for special models like hybrid GDN, so we don't
-    need to change the code of the original attention backend.
-    """
-    assert not (
-        runner.is_hybrid_gdn and runner.use_mla_backend
-    ), "hybrid_gdn can only be used with non-MLA models."
-
-    # wrap for hybrid GDN models
-    if runner.is_hybrid_gdn:
-        from sglang.srt.utils import is_blackwell, is_npu
-
-        if is_blackwell():
-            assert (
-                runner.server_args.attention_backend == "triton"
-            ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
-        if is_npu():
-            assert (
-                runner.server_args.attention_backend == "ascend"
-            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
-        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
-        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-            HybridLinearAttnBackend,
-            MambaAttnBackend,
-        )
-
-        linear_attn_backend = MambaAttnBackend(runner)
-        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
-        return HybridLinearAttnBackend(
-            full_attn_backend, linear_attn_backend, full_attn_layers
-        )
-
-    return full_attn_backend
+@register_attention_backend("hybrid_linear_attn")
+def create_hybrid_linear_attn_backend(runner):
+    assert (
+        runner.is_hybrid_gdn
+    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
+    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+        HybridLinearAttnBackend,
+        MambaAttnBackend,
+    )
+    from sglang.srt.utils import is_blackwell, is_npu
+
+    if is_npu():
+        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+        full_attn_backend = AscendAttnBackend(runner)
+    elif is_blackwell():
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        full_attn_backend = TritonAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        full_attn_backend = FlashAttentionBackend(runner)
+
+    linear_attn_backend = MambaAttnBackend(runner)
+    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+    return HybridLinearAttnBackend(
+        full_attn_backend, linear_attn_backend, full_attn_layers
+    )
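Note (reviewer sketch, not from this commit): the "nsa" and "hybrid_linear_attn" registrations above plug into the file's decorator-populated `ATTENTION_BACKENDS` dict, so adding a backend just maps a name (e.g. the value passed to --attention-backend) to a constructor, and selection reduces to a dictionary lookup. A minimal sketch of that pattern follows; the `runner` argument and the returned tuple are placeholders for the real model runner and backend object.

from typing import Any, Callable, Dict

ATTENTION_BACKENDS: Dict[str, Callable[..., Any]] = {}


def register_attention_backend(name: str) -> Callable:
    # Decorator: record the constructor under its backend name.
    def decorator(fn: Callable) -> Callable:
        ATTENTION_BACKENDS[name] = fn
        return fn

    return decorator


@register_attention_backend("nsa")
def create_nsa_backend(runner: Any) -> Any:
    # The real registration imports and returns NativeSparseAttnBackend(runner).
    return ("NativeSparseAttnBackend", runner)


# Backend selection by name then reduces to a dict lookup:
backend = ATTENTION_BACKENDS["nsa"](runner=None)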