change / sglang · Commits · 01bdbf7f

Commit 01bdbf7f (unverified), authored May 11, 2025 by Lianmin Zheng, committed via GitHub on May 11, 2025
Parent: 94d42b67

Improve structured outputs: fix race condition, server crash, metrics and style (#6188)

Showing 13 changed files with 568 additions and 258 deletions (+568 / -258)
Changed files:

    docs/backend/structured_outputs_for_reasoning_models.ipynb     +14  -12
    python/sglang/srt/constrained/base_grammar_backend.py          +49  -72
    python/sglang/srt/constrained/llguidance_backend.py            +25  -21
    python/sglang/srt/constrained/outlines_backend.py              +27  -26
    python/sglang/srt/constrained/reasoner_grammar_backend.py      +22  -33
    python/sglang/srt/constrained/xgrammar_backend.py              +69  -43
    python/sglang/srt/layers/sampler.py                             +0   -4
    python/sglang/srt/managers/schedule_batch.py                    +1   -0
    python/sglang/srt/managers/scheduler.py                        +33   -8
    python/sglang/srt/managers/tokenizer_manager.py                 +7   -0
    python/sglang/srt/metrics/collector.py                        +312  -37
    test/srt/test_json_constrained.py                               +8   -2
    test/srt/test_metrics.py                                        +1   -0
docs/backend/structured_outputs_for_reasoning_models.ipynb

@@ -94,8 +94,8 @@
     "        model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "        messages=[\n",
     "            {\n",
-    "                \"role\": \"user\",\n",
-    "                \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
+    "                \"role\": \"assistant\",\n",
+    "                \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "            },\n",
     "        ],\n",
     "        temperature=0,\n",
@@ -145,8 +145,8 @@
     "        model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "        messages=[\n",
     "            {\n",
-    "                \"role\": \"user\",\n",
-    "                \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+    "                \"role\": \"assistant\",\n",
+    "                \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "            },\n",
     "        ],\n",
     "        temperature=0,\n",
@@ -188,8 +188,8 @@
     "        messages=[\n",
     "            {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
     "            {\n",
-    "                \"role\": \"user\",\n",
-    "                \"content\": \"Give me the information of the capital of France.\",\n",
+    "                \"role\": \"assistant\",\n",
+    "                \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "            },\n",
     "        ],\n",
     "        temperature=0,\n",
@@ -218,7 +218,7 @@
     "response = client.chat.completions.create(\n",
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
-    "        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
+    "        {\"role\": \"assistant\", \"content\": \"What is the capital of France?\"},\n",
     "    ],\n",
     "    temperature=0,\n",
     "    max_tokens=2048,\n",
@@ -323,7 +323,7 @@
     "You are a helpful assistant.\"\"\",\n",
     "    },\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
+    "        \"role\": \"assistant\",\n",
     "        \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
     "    },\n",
     "]\n",
@@ -400,9 +400,9 @@
     "\n",
     "messages = [\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
-    "        \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
-    "    }\n",
+    "        \"role\": \"assistant\",\n",
+    "        \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
+    "    },\n",
     "]\n",
     "text = tokenizer.apply_chat_template(\n",
     "    messages, tokenize=False, add_generation_prompt=True\n",
@@ -452,7 +452,9 @@
     ")\n",
     "\n",
     "# JSON\n",
-    "text = tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)\n",
+    "text = tokenizer.apply_chat_template(\n",
+    "    messages, tokenize=False, add_generation_prompt=True\n",
+    ")\n",
     "response = requests.post(\n",
     "    f\"http://localhost:{port}/generate\",\n",
     "    json={\n",
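Note on the last hunk above: the previous notebook cell passed the already-rendered `text` string back into `apply_chat_template`, which expects the message list. A minimal standalone sketch of the corrected call pattern follows; the message content is illustrative and not taken verbatim from the notebook.

# Sketch only: assumes the transformers tokenizer used elsewhere in the notebook.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

messages = [
    {"role": "user", "content": "Give me the information of the capital of France in the JSON format."},
]

# Pass the message list, not a pre-rendered string, and let the chat template
# append the generation prompt for the next turn.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)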
python/sglang/srt/constrained/base_grammar_backend.py

@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""

 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)


-class BaseGrammarObject(ABC):
+class BaseGrammarObject:

     def __init__(self):
         self._finished = False

+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished
@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished

-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.
@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
             A jump forward helper which may be used in `jump_forward_str_state`.
             None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.
@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
             A tuple of the jump forward string and the next state of the grammar
             (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
         """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()


 @dataclass
@@ -113,10 +103,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()

     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

     def dispatch_fallback(
         self, key_type: str, key_string: str
@@ -148,40 +137,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
         elif key_type == "structural_pattern":
             return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)

-    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False

-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value

     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()


 def create_grammar_backend(
@@ -211,9 +185,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )

         grammar_backend = ReasonerGrammarBackend(grammar_backend, tokenizer.think_end_id)

     return grammar_backend
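The refactor above removes the Lock/Event-protected `_init_value` path and replaces it with two smaller entry points: `get_cached_or_future_value`, which returns either a ready copy or a Future submitted to the executor, and `set_cache`, which the caller invokes once the compiled grammar is available. A rough usage sketch under those assumptions follows; it requires sglang with the post-commit interface to be importable, and `DummyGrammar`/`DummyBackend` are hypothetical stand-ins, not part of the commit.

# Hypothetical subclass used only to exercise the new cache API shown above.
from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)


class DummyGrammar(BaseGrammarObject):
    def copy(self):
        return DummyGrammar()


class DummyBackend(BaseGrammarBackend):
    def _init_value_dispatch(self, key):
        # Pretend this is slow grammar compilation for `key`.
        return DummyGrammar()


backend = DummyBackend()
key = ("json", '{"type": "object"}')

value, cache_hit = backend.get_cached_or_future_value(key)
if not cache_hit:
    grammar = value.result()  # the Future resolves off the scheduler's hot path
    backend.set_cache(key, grammar.copy())
else:
    grammar = value  # a ready-to-use copy from the cache
print(type(grammar).__name__, cache_hit)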
python/sglang/srt/constrained/llguidance_backend.py

@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None

-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+

 class GuidanceBackend(BaseGrammarBackend):
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None

     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
         return self._from_serialized(serialized_grammar)

     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
python/sglang/srt/constrained/outlines_backend.py

@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-

 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
            )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
             return None
         return self._compile_regex(regex)

     def dispatch_regex(self, key_string: str):
python/sglang/srt/constrained/reasoner_grammar_backend.py

@@ -13,7 +13,6 @@
 # ==============================================================================
 """The baseclass of a backend for reasoner grammar-guided constrained decoding."""

-from concurrent.futures import Future
 from typing import List, Optional, Tuple

 import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
         self.think_end_id = think_end_id
         self.is_in_reasoning = True

-    @property
-    def finished(self):
-        return self.grammar.finished
-
-    @finished.setter
-    def finished(self, finished):
-        self.grammar.finished = finished
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
     def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask

-    def accept_token(self, token: int):
-        if token == self.think_end_id:
-            self.is_in_reasoning = False
-
-        if not self.is_in_reasoning and token != self.think_end_id:
-            self.grammar.accept_token(token)
-
-    def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished

     def try_jump_forward(self, tokenizer):
         return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
             old_output_ids, new_output_ids, next_state
         )

+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+

 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        super().__init__()
         self.grammar_backend = grammar_backend
         self.think_end_id = think_end_id

-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
-        grammar = self.grammar_backend.get_cached_value(key)
-        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        grammar = Future()
-
-        def callback(f: Future):
-            if result := f.result():
-                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
-            else:
-                grammar.set_result(None)
-
-        self.grammar_backend.get_future_value(key).add_done_callback(callback)
-        return grammar
-
-    def reset(self):
-        self.grammar_backend.reset()
+    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+        ret = self.grammar_backend._init_value_dispatch(key)
+        if ret is None:
+            return None
+        return ReasonerGrammarObject(ret, self.think_end_id)
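As reordered above, `ReasonerGrammarObject.accept_token` only forwards tokens to the wrapped grammar once the reasoning section has been closed by `think_end_id`. A toy sketch of that gating logic in isolation; the token ids and the mock grammar are made up for illustration.

# Illustrative sketch of the think_end_id gating used by ReasonerGrammarObject.
THINK_END_ID = 7  # hypothetical id of the token that closes the reasoning block


class MockGrammar:
    def __init__(self):
        self.accepted = []

    def accept_token(self, token):
        self.accepted.append(token)


grammar = MockGrammar()
is_in_reasoning = True

for token in [1, 2, THINK_END_ID, 5, 6]:
    if token == THINK_END_ID:
        is_in_reasoning = False
    # Only constrain tokens emitted after the reasoning block.
    if not is_in_reasoning and token != THINK_END_ID:
        grammar.accept_token(token)

print(grammar.accepted)  # [5, 6] -- reasoning tokens 1 and 2 were not constrained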
python/sglang/srt/constrained/xgrammar_backend.py

@@ -18,7 +18,6 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
-import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
 from sglang.srt.constrained.triton_ops.bitmask_ops import (
     apply_token_bitmask_inplace_triton,
 )
-from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
         vocab_size: int,
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
+        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
-        self.finished = False
-
-        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
-            apply_token_bitmask_inplace_cpu,
-        )
-
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
+        self.accepted_tokens = []
+        self.key_string = key_string

     def accept_token(self, token: int):
-        assert self.matcher.accept_token(token)
-
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        s = self.matcher.find_jump_forward_string()
-        if s:
-            return [], s
-        return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        _, data = helper
-        return data, -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        k = 0
-        for i, old_id in enumerate(old_output_ids):
-            if old_id == new_output_ids[i]:
-                k = i + 1
-            else:
-                break
-
-        # rollback to the last token that is the same
-        if k < len(old_output_ids):
-            self.matcher.rollback(len(old_output_ids) - k)
-
-        for i in range(k, len(new_output_ids)):
-            assert self.matcher.accept_token(new_output_ids[i])
+        if not self.is_terminated():
+            accepted = self.matcher.accept_token(token)
+            if not accepted:
+                # log for debugging
+                raise ValueError(
+                    f"Tokens not accepted: {token}\n"
+                    f"Accepted tokens: {self.accepted_tokens}\n"
+                    f"Key string: {self.key_string}"
+                )
+            else:
+                self.accepted_tokens.append(token)
+
+    def rollback(self, k: int):
+        self.matcher.rollback(k)
+        self.accepted_tokens = self.accepted_tokens[:-k]
+
+    def is_terminated(self):
+        return self.matcher.is_terminated()

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
             override_stop_tokens=self.override_stop_tokens,
         )
         return XGrammarGrammar(
-            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
+            matcher,
+            self.vocab_size,
+            self.ctx,
+            self.override_stop_tokens,
+            self.key_string,
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, data = helper
+        return data, -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == new_output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # rollback to the last token that is the same
+        if k < len(old_output_ids):
+            self.matcher.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(new_output_ids)):
+            assert self.matcher.accept_token(new_output_ids[i])
+
+    def __repr__(self):
+        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
+

 class XGrammarGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self.vocab_size = vocab_size
         self.override_stop_tokens = override_stop_tokens

-    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
-        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
-        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
+    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
+        )
+        return XGrammarGrammar(
+            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+        )

     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 tags, structural_tag["triggers"]
             )
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+            logging.warning(
+                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+            )
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def reset(self):
         if self.grammar_compiler:
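The XGrammarGrammar changes above keep a list of accepted tokens so that a rejected token is reported (via the new ValueError and `__repr__`) together with the grammar key instead of failing on a bare assert, and `rollback` trims the same bookkeeping. A toy sketch of that pattern with a faked matcher, so it has no xgrammar dependency:

# Toy sketch of the accepted-token bookkeeping added to XGrammarGrammar.
class FakeMatcher:
    """Accepts every token; stands in for xgrammar.GrammarMatcher here."""

    def accept_token(self, token):
        return True

    def rollback(self, k):
        pass

    def is_terminated(self):
        return False


class TrackedGrammar:
    def __init__(self, matcher, key_string=None):
        self.matcher = matcher
        self.accepted_tokens = []
        self.key_string = key_string

    def accept_token(self, token):
        if not self.is_terminated():
            if not self.matcher.accept_token(token):
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            self.accepted_tokens.append(token)

    def rollback(self, k):
        self.matcher.rollback(k)
        self.accepted_tokens = self.accepted_tokens[:-k]

    def is_terminated(self):
        return self.matcher.is_terminated()


g = TrackedGrammar(FakeMatcher(), key_string='{"type": "object"}')
for t in [10, 11, 12]:
    g.accept_token(t)
g.rollback(1)
print(g.accepted_tokens)  # [10, 11]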
python/sglang/srt/layers/sampler.py

@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(

 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
python/sglang/srt/managers/schedule_batch.py

@@ -533,6 +533,7 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
python/sglang/srt/managers/scheduler.py

@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
@@ -1024,9 +1025,11 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
@@ -1208,6 +1211,7 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
         self.stats.gen_throughput = self.last_gen_throughput
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        self.stats.spec_accept_length = spec_accept_length
        self.metrics_collector.log_stats(self.stats)
@@ -1255,6 +1259,7 @@ class Scheduler(
         self.stats.token_usage = num_used / self.max_total_num_tokens
         self.stats.gen_throughput = 0
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1715,11 +1720,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.05)
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1731,14 +1742,28 @@ class Scheduler(
         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
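The scheduler changes above also bound how long a request may sit in the grammar queue: every poll that times out bumps `grammar_wait_ct`, and once the count exceeds `SGLANG_GRAMMAR_TIMEOUT / 0.03` the request is aborted with a 400 instead of blocking the server indefinitely. A back-of-the-envelope sketch of that threshold arithmetic:

# Sketch of the timeout bookkeeping used in the grammar-queue loop above.
import os

GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
POLL_INTERVAL = 0.03  # the timeout passed to Future.result() in the loop above

# Each timed-out poll increments grammar_wait_ct by one, so a request is aborted
# after roughly GRAMMAR_TIMEOUT / POLL_INTERVAL unsuccessful polls (wall time can
# be longer, since polls only happen when the scheduler revisits the queue).
max_polls = GRAMMAR_TIMEOUT / POLL_INTERVAL
print(f"abort after about {int(max_polls)} timed-out polls")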
python/sglang/srt/managers/tokenizer_manager.py

@@ -1230,11 +1230,18 @@ class TokenizerManager:
                 state.last_completion_tokens = completion_tokens

             if state.finished:
+                has_grammar = (
+                    state.obj.sampling_params.get("json_schema", None)
+                    or state.obj.sampling_params.get("regex", None)
+                    or state.obj.sampling_params.get("ebnf", None)
+                    or state.obj.sampling_params.get("structural_tag", None)
+                )
                 self.metrics_collector.observe_one_finished_request(
                     recv_obj.prompt_tokens[i],
                     completion_tokens,
                     recv_obj.cached_tokens[i],
                     state.finished_time - state.created_time,
+                    has_grammar,
                 )

     def dump_requests(self, state: ReqState, out_dict: dict):
python/sglang/srt/metrics/collector.py

@@ -15,7 +15,119 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Union
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.utils import get_bool_env_var
+
+SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
+
+
+@dataclass
+class TimeStats:
+    """
+    Store the timestamps for each stage of a request.
+
+    Unified: wait_queue -> forward -> completion
+    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
+    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
+    """
+
+    lb_entry_time: float = 0.0
+    wait_queue_entry_time: float = 0.0
+    forward_entry_time: float = 0.0
+    completion_time: float = 0.0
+    prefill_bootstrap_queue_entry_time: float = 0.0
+    prefill_transfer_queue_entry_time: float = 0.0
+    decode_prealloc_queue_entry_time: float = 0.0
+    decode_transfer_queue_entry_time: float = 0.0
+
+    class RequestType(Enum):
+        UNIFIED = "unified"
+        PREFILL = "prefill"
+        DECODE = "decode"
+        INVALID = "invalid"
+
+    def __str__(self) -> str:
+        # if unified
+        _type = self.get_type()
+
+        if _type == self.RequestType.UNIFIED:
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    queue_duration >= 0 and forward_duration >= 0
+                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
+        elif _type == self.RequestType.PREFILL:
+            bootstrap_duration = (
+                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    bootstrap_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
+        # if decode
+        elif _type == self.RequestType.DECODE:
+            prealloc_duration = (
+                self.decode_transfer_queue_entry_time
+                - self.decode_prealloc_queue_entry_time
+            )
+            transfer_duration = (
+                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    prealloc_duration >= 0
+                    and transfer_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
+        else:
+            return "Invalid Time Stats"
+
+    def format_duration(self, duration: float) -> str:
+        return f"{duration * 1e3:.2f}ms"
+
+    def get_type(self) -> RequestType:
+        """Determine the type of request based on timestamp values."""
+        if (
+            self.prefill_bootstrap_queue_entry_time == 0.0
+            and self.prefill_transfer_queue_entry_time == 0.0
+            and self.decode_prealloc_queue_entry_time == 0.0
+            and self.decode_transfer_queue_entry_time == 0.0
+        ):
+            return self.RequestType.UNIFIED
+        elif (
+            self.prefill_bootstrap_queue_entry_time > 0.0
+            and self.prefill_transfer_queue_entry_time > 0.0
+        ):
+            return self.RequestType.PREFILL
+        elif (
+            self.decode_prealloc_queue_entry_time > 0.0
+            and self.decode_transfer_queue_entry_time > 0.0
+            and self.wait_queue_entry_time > 0.0
+        ):
+            return self.RequestType.DECODE
+        else:
+            return self.RequestType.INVALID
+

 @dataclass
@@ -26,15 +138,20 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    num_grammar_queue_reqs: int = 0
     spec_accept_length: float = 0.0
     avg_request_queue_latency: float = 0.0
+    num_prefill_prealloc_queue_reqs: int = 0
+    num_prefill_infight_queue_reqs: int = 0
+    num_decode_prealloc_queue_reqs: int = 0
+    num_decode_transfer_queue_reqs: int = 0


 class SchedulerMetricsCollector:
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge, Histogram
+        from prometheus_client import Counter, Gauge

         self.labels = labels
         self.last_log_time = time.time()
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.num_grammar_queue_reqs = Gauge(
+            name="sglang:num_grammar_queue_reqs",
+            documentation="The number of requests in the grammar waiting queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
             documentation="The prefix cache hit rate.",
@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        # Disaggregation queue metrics
+        self.num_prefill_prealloc_queue_reqs = Gauge(
+            name="sglang:num_prefill_prealloc_queue_reqs",
+            documentation="The number of requests in the prefill prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_prefill_infight_queue_reqs = Gauge(
+            name="sglang:num_prefill_infight_queue_reqs",
+            documentation="The number of requests in the prefill infight queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_decode_prealloc_queue_reqs = Gauge(
+            name="sglang:num_decode_prealloc_queue_reqs",
+            documentation="The number of requests in the decode prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_decode_transfer_queue_reqs = Gauge(
+            name="sglang:num_decode_transfer_queue_reqs",
+            documentation="The number of requests in the decode transfer queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.num_bootstrap_failed_reqs = Counter(
+            name="sglang:num_bootstrap_failed_reqs",
+            documentation="The number of bootstrap failed requests.",
+            labelnames=labels.keys(),
+        )
+        self.num_transfer_failed_reqs = Counter(
+            name="sglang:num_transfer_failed_reqs",
+            documentation="The number of transfer failed requests.",
+            labelnames=labels.keys(),
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)

+    def increment_bootstrap_failed_reqs(self) -> None:
+        self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
+
+    def increment_transfer_failed_reqs(self) -> None:
+        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
+
     def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
         self._log_gauge(self.token_usage, stats.token_usage)
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+        self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
         self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
+
+        # Disaggregation metrics
+        self._log_gauge(self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs)
+        self._log_gauge(self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs)
+        self._log_gauge(self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs)
+        self._log_gauge(self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs)
+
         self.last_log_time = time.time()


 class TokenizerMetricsCollector:
-    def __init__(self, labels: Dict[str, str]) -> None:
+    def __init__(
+        self,
+        labels: Dict[str, str],
+        bucket_time_to_first_token: Optional[List[float]] = None,
+        bucket_inter_token_latency: Optional[List[float]] = None,
+        bucket_e2e_request_latency: Optional[List[float]] = None,
+        collect_tokens_histogram: bool = False,
+    ) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
         from prometheus_client import Counter, Histogram

         self.labels = labels
+        self.collect_tokens_histogram = collect_tokens_histogram

         self.prompt_tokens_total = Counter(
             name="sglang:prompt_tokens_total",
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

+        if collect_tokens_histogram:
+            bucket_prompt_tokens = [
+                100, 300, 500, 700, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, 8000,
+                9000, 10000, 12000, 15000, 20000, 22000, 25000, 30000, 35000, 40000,
+            ]
+            self.prompt_tokens_histogram = Histogram(
+                name="sglang:prompt_tokens_histogram",
+                documentation="Histogram of prompt token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_prompt_tokens,
+            )
+            bucket_generation_tokens = [
+                100, 300, 500, 1000, 1200, 1500, 1700, 2000, 2500, 3000, 3500, 4000, 4500,
+                5000, 6000, 7000, 8000, 9000, 10000,
+            ]
+            self.generation_tokens_histogram = Histogram(
+                name="sglang:generation_tokens_histogram",
+                documentation="Histogram of generation token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_generation_tokens,
+            )
+
         self.cached_tokens_total = Counter(
             name="sglang:cached_tokens_total",
             documentation="Number of cached prompt tokens.",
@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

-        self.histogram_time_to_first_token = Histogram(
-            name="sglang:time_to_first_token_seconds",
-            documentation="Histogram of time to first token in seconds.",
+        self.num_so_requests_total = Counter(
+            name="sglang:num_so_requests_total",
+            documentation="Number of structured output requests processed.",
             labelnames=labels.keys(),
-            buckets=[
+        )
+
+        if bucket_time_to_first_token is None:
+            bucket_time_to_first_token = [
                 0.1,
                 0.2,
                 0.4,
@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
                 100,
                 200,
                 400,
-            ],
-        )
+            ]

-        self.histogram_inter_token_latency_seconds = Histogram(
-            name="sglang:inter_token_latency_seconds",
-            documentation="Histogram of inter-token latency in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
+        if bucket_e2e_request_latency is None:
+            bucket_e2e_request_latency = [
+                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10,
+                20, 40, 60, 80, 100, 200, 400, 800,
+            ]
+
+        if bucket_inter_token_latency is None:
+            bucket_inter_token_latency = [
                 0.002,
                 0.004,
                 0.006,
@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
                 4.000,
                 6.000,
                 8.000,
-            ],
-        )
+            ]
+
+        self.histogram_time_to_first_token = Histogram(
+            name="sglang:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_time_to_first_token,
+        )
+
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_inter_token_latency,
+        )

         self.histogram_e2e_request_latency = Histogram(
             name="sglang:e2e_request_latency_seconds",
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
-            buckets=[
-                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10,
-                20, 40, 60, 80, 100, 200, 400, 800,
-            ],
+            buckets=bucket_e2e_request_latency,
         )

     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
+        has_grammar: bool,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         if cached_tokens > 0:
             self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        if has_grammar:
+            self.num_so_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if self.collect_tokens_histogram:
+            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
+            self._log_histogram(self.generation_tokens_histogram, generation_tokens)

     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
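With the constructor change above, the tokenizer-side collector now accepts user-supplied latency buckets and an opt-in token-length histogram, and tags structured-output requests separately. A hedged construction sketch follows; the label key and values are placeholders (in sglang they are filled from the server arguments), and prometheus_client must be importable. The parameter names match the diff; treat anything else as an assumption.

# Illustrative construction of TokenizerMetricsCollector with the new arguments.
from sglang.srt.metrics.collector import TokenizerMetricsCollector

collector = TokenizerMetricsCollector(
    labels={"model_name": "demo-model"},  # placeholder label set
    bucket_time_to_first_token=[0.1, 0.5, 1, 2, 5],
    bucket_inter_token_latency=[0.005, 0.01, 0.05, 0.1],
    bucket_e2e_request_latency=[1, 5, 10, 60],
    collect_tokens_histogram=True,
)

# has_grammar marks structured-output requests for sglang:num_so_requests_total.
collector.observe_one_finished_request(
    prompt_tokens=128,
    generation_tokens=64,
    cached_tokens=32,
    e2e_latency=1.7,
    has_grammar=True,
)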
test/srt/test_json_constrained.py

@@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
         print(json.dumps(ret))
         print("=" * 100)

-        if not json_schema:
+        if not json_schema or json_schema == "INVALID":
             return

         # Make sure the json output is valid
@@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

+    def test_json_invalid(self):
+        self.run_decode(json_schema="INVALID")
+
     def test_json_openai(self):
         client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
@@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
             model=self.model,
             messages=[
                 {"role": "system", "content": "You are a helpful AI assistant"},
-                {"role": "user", "content": "Introduce the capital of France."},
+                {
+                    "role": "user",
+                    "content": "Introduce the capital of France. Return in a JSON format.",
+                },
             ],
             temperature=0,
             max_tokens=128,
test/srt/test_metrics.py

@@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase):
                 "sglang:token_usage",
                 "sglang:gen_throughput",
                 "sglang:num_queue_reqs",
+                "sglang:num_grammar_queue_reqs",
                 "sglang:cache_hit_rate",
                 "sglang:spec_accept_length",
                 "sglang:prompt_tokens_total",