change / sglang / Commits / 01bdbf7f

Unverified commit 01bdbf7f, authored May 11, 2025 by Lianmin Zheng, committed by GitHub on May 11, 2025.

Improve structured outputs: fix race condition, server crash, metrics and style (#6188)

Parent: 94d42b67

Showing 13 changed files with 568 additions and 258 deletions (+568, -258).
docs/backend/structured_outputs_for_reasoning_models.ipynb    +14  -12
python/sglang/srt/constrained/base_grammar_backend.py         +49  -72
python/sglang/srt/constrained/llguidance_backend.py           +25  -21
python/sglang/srt/constrained/outlines_backend.py             +27  -26
python/sglang/srt/constrained/reasoner_grammar_backend.py     +22  -33
python/sglang/srt/constrained/xgrammar_backend.py             +69  -43
python/sglang/srt/layers/sampler.py                            +0   -4
python/sglang/srt/managers/schedule_batch.py                   +1   -0
python/sglang/srt/managers/scheduler.py                       +33   -8
python/sglang/srt/managers/tokenizer_manager.py                +7   -0
python/sglang/srt/metrics/collector.py                       +312  -37
test/srt/test_json_constrained.py                              +8   -2
test/srt/test_metrics.py                                       +1   -0
docs/backend/structured_outputs_for_reasoning_models.ipynb

@@ -94,8 +94,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -145,8 +145,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -188,8 +188,8 @@
     "    messages=[\n",
     "        {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -218,7 +218,7 @@
     "response = client.chat.completions.create(\n",
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
-    "        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
+    "        {\"role\": \"assistant\", \"content\": \"What is the capital of France?\"},\n",
     "    ],\n",
     "    temperature=0,\n",
     "    max_tokens=2048,\n",
@@ -323,7 +323,7 @@
     "You are a helpful assistant.\"\"\",\n",
     "    },\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
+    "        \"role\": \"assistant\",\n",
     "        \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
     "    },\n",
     "]\n",
@@ -400,9 +400,9 @@
     "\n",
     "messages = [\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
-    "        \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
-    "    }\n",
+    "        \"role\": \"assistant\",\n",
+    "        \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
+    "    },\n",
     "]\n",
     "text = tokenizer.apply_chat_template(\n",
     "    messages, tokenize=False, add_generation_prompt=True\n",
@@ -452,7 +452,9 @@
     ")\n",
     "\n",
     "# JSON\n",
-    "text = tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)\n",
+    "text = tokenizer.apply_chat_template(\n",
+    "    messages, tokenize=False, add_generation_prompt=True\n",
+    ")\n",
     "response = requests.post(\n",
     "    f\"http://localhost:{port}/generate\",\n",
     "    json={\n",
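The notebook cells above exercise sglang's OpenAI-compatible structured-output API against a local server. As a rough, hedged sketch of the kind of request these cells build (the port, schema, and schema name below are assumptions for illustration and are not part of this commit):

    import openai

    # Hypothetical local sglang server; adjust base_url and model to your deployment.
    client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

    json_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }

    response = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        messages=[
            {
                "role": "user",
                "content": "Give me the information and population of the capital of France in the JSON format.",
            },
        ],
        temperature=0,
        max_tokens=2048,
        # Constrain the answer to the schema above (OpenAI-style json_schema response format).
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "capital_info", "schema": json_schema},
        },
    )
    print(response.choices[0].message.content)

For reasoning models the returned content may still contain a thinking section unless the server is started with a reasoning parser, as the notebook does.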
python/sglang/srt/constrained/base_grammar_backend.py

@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""

 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)


-class BaseGrammarObject(ABC):
+class BaseGrammarObject:

     def __init__(self):
         self._finished = False

+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished
@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished

-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.
@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
             A jump forward helper which may be used in `jump_forward_str_state`.
             None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.
@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
             A tuple of the jump forward string and the next state of the grammar
             (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
         """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()


 @dataclass
@@ -113,10 +103,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()

     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

     def dispatch_fallback(
         self, key_type: str, key_string: str
@@ -148,40 +137,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)

-    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False

-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value

     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()


 def create_grammar_backend(
@@ -211,9 +185,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )

         grammar_backend = ReasonerGrammarBackend(
             grammar_backend, tokenizer.think_end_id
         )
     return grammar_backend
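The race-condition fix above replaces the lock-and-event cache with a simpler contract: a lookup either returns a ready grammar (as a copy, with a cache-hit flag) or a Future submitted to the thread pool, and the caller later publishes the resolved grammar back with set_cache. A minimal, self-contained sketch of that pattern, using hypothetical names and a string standing in for a compiled grammar object:

    from concurrent.futures import ThreadPoolExecutor
    from typing import Dict, Tuple

    class ToyBackend:
        def __init__(self):
            self.executor = ThreadPoolExecutor()
            self.cache: Dict[Tuple[str, str], str] = {}

        def _init_value_dispatch(self, key):
            # Stand-in for compiling a grammar from (key_type, key_string).
            key_type, key_string = key
            return f"compiled<{key_type}:{key_string}>"

        def get_cached_or_future_value(self, key):
            value = self.cache.get(key)
            if value:
                return value, True  # cache hit: the real backend returns a copy here
            return self.executor.submit(self._init_value_dispatch, key), False

        def set_cache(self, key, value):
            # The caller publishes the finished grammar once the Future resolves.
            self.cache[key] = value

    backend = ToyBackend()
    value, hit = backend.get_cached_or_future_value(("json", "{}"))
    if not hit:
        value = value.result()  # the scheduler instead polls with a short timeout
        backend.set_cache(("json", "{}"), value)
    print(value, backend.get_cached_or_future_value(("json", "{}"))[1])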
python/sglang/srt/constrained/llguidance_backend.py

@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None

-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+

 class GuidanceBackend(BaseGrammarBackend):
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None

     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
         return self._from_serialized(serialized_grammar)

     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
python/sglang/srt/constrained/outlines_backend.py

@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-

 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
             )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
             return None
         return self._compile_regex(regex)

     def dispatch_regex(self, key_string: str):
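The OutlinesGrammar methods moved above encode the vocab-mask convention shared by these backends: a boolean mask in which True marks banned token ids, filled row by row from the guide's allowed set and applied to the logits as -inf. A standalone PyTorch sketch of that convention, with made-up sizes and allowed ids:

    import torch

    vocab_size, batch_size = 8, 2
    vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)

    allowed = torch.tensor([1, 4, 5], dtype=torch.int64)  # tokens the guide permits
    row = vocab_mask[0]                                    # view into the batch row
    row.fill_(1)                                           # ban everything ...
    row.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))  # ... then unban

    logits = torch.randn(batch_size, vocab_size)
    logits.masked_fill_(vocab_mask, float("-inf"))
    print(logits[0])  # only positions 1, 4, 5 remain finite in row 0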
python/sglang/srt/constrained/reasoner_grammar_backend.py

@@ -13,7 +13,6 @@
 # ==============================================================================
 """The baseclass of a backend for reasoner grammar-guided constrained decoding."""

-from concurrent.futures import Future
 from typing import List, Optional, Tuple

 import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
         self.think_end_id = think_end_id
         self.is_in_reasoning = True

-    @property
-    def finished(self):
-        return self.grammar.finished
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False

-    @finished.setter
-    def finished(self, finished):
-        self.grammar.finished = finished
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
     def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask

-    def accept_token(self, token: int):
-        if token == self.think_end_id:
-            self.is_in_reasoning = False
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)

-        if not self.is_in_reasoning and token != self.think_end_id:
-            self.grammar.accept_token(token)
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished

     def try_jump_forward(self, tokenizer):
         return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
             old_output_ids, new_output_ids, next_state
         )

-    def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
-

 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        super().__init__()
         self.grammar_backend = grammar_backend
         self.think_end_id = think_end_id

-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
-        grammar = self.grammar_backend.get_cached_value(key)
-        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        grammar = Future()
-
-        def callback(f: Future):
-            if result := f.result():
-                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
-            else:
-                grammar.set_result(None)
-
-        self.grammar_backend.get_future_value(key).add_done_callback(callback)
-        return grammar
-
-    def reset(self):
-        self.grammar_backend.reset()
+    def _init_value_dispatch(
+        self, key: Tuple[str, str]
+    ) -> Optional[ReasonerGrammarObject]:
+        ret = self.grammar_backend._init_value_dispatch(key)
+        if ret is None:
+            return None
+        return ReasonerGrammarObject(ret, self.think_end_id)
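The relocated accept_token gates constrained decoding on the reasoning phase: tokens emitted before the think-end token bypass the wrapped grammar, and the think-end token itself is never forwarded. A small sketch of that gating with made-up token ids and a toy grammar:

    THINK_END_ID = 7  # hypothetical id of the think-end token

    class ToyGrammar:
        def __init__(self):
            self.seen = []

        def accept_token(self, token):
            self.seen.append(token)

    def feed(grammar, tokens):
        is_in_reasoning = True
        for token in tokens:
            if token == THINK_END_ID:
                is_in_reasoning = False
            if not is_in_reasoning and token != THINK_END_ID:
                grammar.accept_token(token)

    g = ToyGrammar()
    feed(g, [3, 1, 4, THINK_END_ID, 9, 2, 6])
    print(g.seen)  # only [9, 2, 6] reach the constrained-decoding grammar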
python/sglang/srt/constrained/xgrammar_backend.py

@@ -18,7 +18,6 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
-import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
 from sglang.srt.constrained.triton_ops.bitmask_ops import (
     apply_token_bitmask_inplace_triton,
 )
-from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
         vocab_size: int,
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
+        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
-
-        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
-            apply_token_bitmask_inplace_cpu,
-        )
-
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
+        self.accepted_tokens = []
+        self.key_string = key_string

     def accept_token(self, token: int):
-        assert self.matcher.accept_token(token)
-
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        s = self.matcher.find_jump_forward_string()
-        if s:
-            return [], s
-        return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        _, data = helper
-        return data, -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        k = 0
-        for i, old_id in enumerate(old_output_ids):
-            if old_id == new_output_ids[i]:
-                k = i + 1
-            else:
-                break
-
-        # rollback to the last token that is the same
-        if k < len(old_output_ids):
-            self.matcher.rollback(len(old_output_ids) - k)
-
-        for i in range(k, len(new_output_ids)):
-            assert self.matcher.accept_token(new_output_ids[i])
+        if not self.is_terminated():
+            accepted = self.matcher.accept_token(token)
+            if not accepted:
+                # log for debugging
+                raise ValueError(
+                    f"Tokens not accepted: {token}\n"
+                    f"Accepted tokens: {self.accepted_tokens}\n"
+                    f"Key string: {self.key_string}"
+                )
+            else:
+                self.accepted_tokens.append(token)
+
+    def rollback(self, k: int):
+        self.matcher.rollback(k)
+        self.accepted_tokens = self.accepted_tokens[:-k]
+
+    def is_terminated(self):
+        return self.matcher.is_terminated()

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
             override_stop_tokens=self.override_stop_tokens,
         )
         return XGrammarGrammar(
-            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
+            matcher, self.vocab_size, self.ctx, self.override_stop_tokens, self.key_string,
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, data = helper
+        return data, -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == new_output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # rollback to the last token that is the same
+        if k < len(old_output_ids):
+            self.matcher.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(new_output_ids)):
+            assert self.matcher.accept_token(new_output_ids[i])
+
+    def __repr__(self):
+        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
+

 class XGrammarGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self.vocab_size = vocab_size
         self.override_stop_tokens = override_stop_tokens

-    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
-        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
-        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
+    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
+        )
+        return XGrammarGrammar(
+            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+        )

     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 tags, structural_tag["triggers"]
             )
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+            logging.warning(
+                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+            )
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def reset(self):
         if self.grammar_compiler:
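jump_and_retokenize, now grouped with the other jump-forward helpers, keeps the matcher consistent after a jump: it finds the longest common prefix of the old and new output ids, rolls back the divergent suffix, and re-accepts the new tokens. A standalone sketch of just that prefix/rollback arithmetic (the ids are made up, and a bounds check is added that the original leaves to the caller):

    def common_prefix_len(old_ids, new_ids):
        k = 0
        for i, old_id in enumerate(old_ids):
            if i < len(new_ids) and old_id == new_ids[i]:
                k = i + 1
            else:
                break
        return k

    old_ids = [11, 12, 13, 14]
    new_ids = [11, 12, 99, 98, 97]
    k = common_prefix_len(old_ids, new_ids)
    rollback = len(old_ids) - k      # tokens the matcher must undo
    reaccept = new_ids[k:]           # tokens to feed back through accept_token
    print(k, rollback, reaccept)     # 2 2 [99, 98, 97]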
python/sglang/srt/layers/sampler.py

@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(

 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
python/sglang/srt/managers/schedule_batch.py

@@ -533,6 +533,7 @@ class Req:

         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
python/sglang/srt/managers/scheduler.py

@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
@@ -1024,9 +1025,11 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
@@ -1208,6 +1211,7 @@ class Scheduler(
         self.stats.cache_hit_rate = 0.0
         self.stats.gen_throughput = self.last_gen_throughput
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.metrics_collector.log_stats(self.stats)
@@ -1255,6 +1259,7 @@ class Scheduler(
         self.stats.token_usage = num_used / self.max_total_num_tokens
         self.stats.gen_throughput = 0
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1715,11 +1720,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.05)
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1731,14 +1742,28 @@ class Scheduler(
         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
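The scheduler change turns the SGLANG_GRAMMAR_TIMEOUT budget into a poll count: each scheduling pass waits a short slice on the grammar Future, and once a request's wait counter exceeds timeout/slice it is aborted instead of stalling the queue forever. A minimal sketch of that polling pattern with made-up numbers (the real code uses a 300 s default budget and 0.03 s slices, and aborts the request rather than just cancelling the Future):

    import time
    from concurrent import futures

    GRAMMAR_TIMEOUT = 0.2   # hypothetical budget in seconds
    POLL = 0.03             # per-pass wait slice

    executor = futures.ThreadPoolExecutor()
    slow = executor.submit(time.sleep, 1)  # stands in for a slow grammar compile

    wait_ct = 0
    aborted = False
    while True:
        try:
            slow.result(timeout=POLL)
            break
        except futures.TimeoutError:
            wait_ct += 1
            if wait_ct > GRAMMAR_TIMEOUT / POLL:
                slow.cancel()  # cannot stop a task already running, only un-queued ones
                aborted = True
                break

    print(f"aborted={aborted} after {wait_ct} polls")
    executor.shutdown(wait=False)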
python/sglang/srt/managers/tokenizer_manager.py

@@ -1230,11 +1230,18 @@ class TokenizerManager:
                 state.last_completion_tokens = completion_tokens

             if state.finished:
+                has_grammar = (
+                    state.obj.sampling_params.get("json_schema", None)
+                    or state.obj.sampling_params.get("regex", None)
+                    or state.obj.sampling_params.get("ebnf", None)
+                    or state.obj.sampling_params.get("structural_tag", None)
+                )
                 self.metrics_collector.observe_one_finished_request(
                     recv_obj.prompt_tokens[i],
                     completion_tokens,
                     recv_obj.cached_tokens[i],
                     state.finished_time - state.created_time,
+                    has_grammar,
                 )

     def dump_requests(self, state: ReqState, out_dict: dict):
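The has_grammar flag added here is just a presence check over the structured-output sampling parameters; any one of them marks the finished request for the new counter. A tiny sketch with a made-up parameter dict:

    sampling_params = {"json_schema": '{"type": "object"}', "temperature": 0}

    has_grammar = (
        sampling_params.get("json_schema", None)
        or sampling_params.get("regex", None)
        or sampling_params.get("ebnf", None)
        or sampling_params.get("structural_tag", None)
    )
    print(bool(has_grammar))  # True -> counted as a structured-output request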
python/sglang/srt/metrics/collector.py

@@ -15,7 +15,119 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Union
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.utils import get_bool_env_var
+
+SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
+
+
+@dataclass
+class TimeStats:
+    """
+    Store the timestamps for each stage of a request.
+
+    Unified: wait_queue -> forward -> completion
+    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
+    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
+    """
+
+    lb_entry_time: float = 0.0
+    wait_queue_entry_time: float = 0.0
+    forward_entry_time: float = 0.0
+    completion_time: float = 0.0
+    prefill_bootstrap_queue_entry_time: float = 0.0
+    prefill_transfer_queue_entry_time: float = 0.0
+    decode_prealloc_queue_entry_time: float = 0.0
+    decode_transfer_queue_entry_time: float = 0.0
+
+    class RequestType(Enum):
+        UNIFIED = "unified"
+        PREFILL = "prefill"
+        DECODE = "decode"
+        INVALID = "invalid"
+
+    def __str__(self) -> str:
+        # if unified
+        _type = self.get_type()
+        if _type == self.RequestType.UNIFIED:
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    queue_duration >= 0 and forward_duration >= 0
+                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
+        elif _type == self.RequestType.PREFILL:
+            bootstrap_duration = (
+                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    bootstrap_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
+        # if decode
+        elif _type == self.RequestType.DECODE:
+            prealloc_duration = (
+                self.decode_transfer_queue_entry_time
+                - self.decode_prealloc_queue_entry_time
+            )
+            transfer_duration = (
+                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    prealloc_duration >= 0
+                    and transfer_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
+        else:
+            return "Invalid Time Stats"
+
+    def format_duration(self, duration: float) -> str:
+        return f"{duration * 1e3:.2f}ms"
+
+    def get_type(self) -> RequestType:
+        """Determine the type of request based on timestamp values."""
+        if (
+            self.prefill_bootstrap_queue_entry_time == 0.0
+            and self.prefill_transfer_queue_entry_time == 0.0
+            and self.decode_prealloc_queue_entry_time == 0.0
+            and self.decode_transfer_queue_entry_time == 0.0
+        ):
+            return self.RequestType.UNIFIED
+        elif (
+            self.prefill_bootstrap_queue_entry_time > 0.0
+            and self.prefill_transfer_queue_entry_time > 0.0
+        ):
+            return self.RequestType.PREFILL
+        elif (
+            self.decode_prealloc_queue_entry_time > 0.0
+            and self.decode_transfer_queue_entry_time > 0.0
+            and self.wait_queue_entry_time > 0.0
+        ):
+            return self.RequestType.DECODE
+        else:
+            return self.RequestType.INVALID


 @dataclass
@@ -26,15 +138,20 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    num_grammar_queue_reqs: int = 0
     spec_accept_length: float = 0.0
     avg_request_queue_latency: float = 0.0
+    num_prefill_prealloc_queue_reqs: int = 0
+    num_prefill_infight_queue_reqs: int = 0
+    num_decode_prealloc_queue_reqs: int = 0
+    num_decode_transfer_queue_reqs: int = 0


 class SchedulerMetricsCollector:
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge, Histogram
+        from prometheus_client import Counter, Gauge

         self.labels = labels
         self.last_log_time = time.time()
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.num_grammar_queue_reqs = Gauge(
+            name="sglang:num_grammar_queue_reqs",
+            documentation="The number of requests in the grammar waiting queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
             documentation="The prefix cache hit rate.",
@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        # Disaggregation queue metrics
+        self.num_prefill_prealloc_queue_reqs = Gauge(
+            name="sglang:num_prefill_prealloc_queue_reqs",
+            documentation="The number of requests in the prefill prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_prefill_infight_queue_reqs = Gauge(
+            name="sglang:num_prefill_infight_queue_reqs",
+            documentation="The number of requests in the prefill infight queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_decode_prealloc_queue_reqs = Gauge(
+            name="sglang:num_decode_prealloc_queue_reqs",
+            documentation="The number of requests in the decode prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_decode_transfer_queue_reqs = Gauge(
+            name="sglang:num_decode_transfer_queue_reqs",
+            documentation="The number of requests in the decode transfer queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_bootstrap_failed_reqs = Counter(
+            name="sglang:num_bootstrap_failed_reqs",
+            documentation="The number of bootstrap failed requests.",
+            labelnames=labels.keys(),
+        )
+        self.num_transfer_failed_reqs = Counter(
+            name="sglang:num_transfer_failed_reqs",
+            documentation="The number of transfer failed requests.",
+            labelnames=labels.keys(),
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)

+    def increment_bootstrap_failed_reqs(self) -> None:
+        self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
+
+    def increment_transfer_failed_reqs(self) -> None:
+        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
+
     def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
         self._log_gauge(self.token_usage, stats.token_usage)
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+        self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
+
+        # Disaggregation metrics
+        self._log_gauge(
+            self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
+        )
+        self._log_gauge(
+            self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
+        )
+        self._log_gauge(
+            self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
+        )
+        self._log_gauge(
+            self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
+        )
+
         self.last_log_time = time.time()


 class TokenizerMetricsCollector:
-    def __init__(self, labels: Dict[str, str]) -> None:
+    def __init__(
+        self,
+        labels: Dict[str, str],
+        bucket_time_to_first_token: Optional[List[float]] = None,
+        bucket_inter_token_latency: Optional[List[float]] = None,
+        bucket_e2e_request_latency: Optional[List[float]] = None,
+        collect_tokens_histogram: bool = False,
+    ) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
         from prometheus_client import Counter, Histogram

         self.labels = labels
+        self.collect_tokens_histogram = collect_tokens_histogram

         self.prompt_tokens_total = Counter(
             name="sglang:prompt_tokens_total",
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

+        if collect_tokens_histogram:
+            bucket_prompt_tokens = [
+                100, 300, 500, 700, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000,
+                8000, 9000, 10000, 12000, 15000, 20000, 22000, 25000, 30000, 35000, 40000,
+            ]
+            self.prompt_tokens_histogram = Histogram(
+                name="sglang:prompt_tokens_histogram",
+                documentation="Histogram of prompt token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_prompt_tokens,
+            )
+            bucket_generation_tokens = [
+                100, 300, 500, 1000, 1200, 1500, 1700, 2000, 2500, 3000, 3500, 4000,
+                4500, 5000, 6000, 7000, 8000, 9000, 10000,
+            ]
+            self.generation_tokens_histogram = Histogram(
+                name="sglang:generation_tokens_histogram",
+                documentation="Histogram of generation token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_generation_tokens,
+            )
+
         self.cached_tokens_total = Counter(
             name="sglang:cached_tokens_total",
             documentation="Number of cached prompt tokens.",
@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

-        self.histogram_time_to_first_token = Histogram(
-            name="sglang:time_to_first_token_seconds",
-            documentation="Histogram of time to first token in seconds.",
+        self.num_so_requests_total = Counter(
+            name="sglang:num_so_requests_total",
+            documentation="Number of structured output requests processed.",
             labelnames=labels.keys(),
-            buckets=[
+        )
+
+        if bucket_time_to_first_token is None:
+            bucket_time_to_first_token = [
                 0.1,
                 0.2,
                 0.4,
@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
                 100,
                 200,
                 400,
-            ],
-        )
+            ]

-        self.histogram_inter_token_latency_seconds = Histogram(
-            name="sglang:inter_token_latency_seconds",
-            documentation="Histogram of inter-token latency in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
+        if bucket_e2e_request_latency is None:
+            bucket_e2e_request_latency = [
+                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800,
+            ]
+
+        if bucket_inter_token_latency is None:
+            bucket_inter_token_latency = [
                 0.002,
                 0.004,
                 0.006,
@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
                 4.000,
                 6.000,
                 8.000,
-            ],
-        )
+            ]
+
+        self.histogram_time_to_first_token = Histogram(
+            name="sglang:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_time_to_first_token,
+        )
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_inter_token_latency,
+        )
         self.histogram_e2e_request_latency = Histogram(
             name="sglang:e2e_request_latency_seconds",
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
-            buckets=[
-                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800,
-            ],
+            buckets=bucket_e2e_request_latency,
         )

     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
+        has_grammar: bool,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         if cached_tokens > 0:
             self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        if has_grammar:
+            self.num_so_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if self.collect_tokens_histogram:
+            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
+            self._log_histogram(self.generation_tokens_histogram, generation_tokens)

     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
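The new sglang:num_so_requests_total metric follows the usual prometheus_client pattern: one Counter labelled like the other tokenizer metrics, incremented only when a finished request was grammar-constrained. A self-contained sketch of that pattern (the label set here is invented; the real collector reuses the server's configured labels):

    from prometheus_client import Counter, generate_latest

    labels = {"model_name": "demo-model"}  # hypothetical label set

    num_so_requests_total = Counter(
        name="sglang:num_so_requests_total",
        documentation="Number of structured output requests processed.",
        labelnames=labels.keys(),
    )

    def observe_one_finished_request(has_grammar: bool):
        if has_grammar:
            num_so_requests_total.labels(**labels).inc(1)

    observe_one_finished_request(True)
    observe_one_finished_request(False)
    print(generate_latest().decode())  # exposition text includes the counter sample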
test/srt/test_json_constrained.py

@@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
         print(json.dumps(ret))
         print("=" * 100)

-        if not json_schema:
+        if not json_schema or json_schema == "INVALID":
             return

         # Make sure the json output is valid
@@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

+    def test_json_invalid(self):
+        self.run_decode(json_schema="INVALID")
+
     def test_json_openai(self):
         client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
@@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
             model=self.model,
             messages=[
                 {"role": "system", "content": "You are a helpful AI assistant"},
-                {"role": "user", "content": "Introduce the capital of France."},
+                {
+                    "role": "user",
+                    "content": "Introduce the capital of France. Return in a JSON format.",
+                },
             ],
             temperature=0,
             max_tokens=128,
test/srt/test_metrics.py

@@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase):
             "sglang:token_usage",
             "sglang:gen_throughput",
             "sglang:num_queue_reqs",
+            "sglang:num_grammar_queue_reqs",
             "sglang:cache_hit_rate",
             "sglang:spec_accept_length",
             "sglang:prompt_tokens_total",