sglang commit 01bdbf7f (unverified)
Authored May 11, 2025 by Lianmin Zheng; committed by GitHub on May 11, 2025.

Improve structured outputs: fix race condition, server crash, metrics and style (#6188)

Parent: 94d42b67
Showing 13 changed files with 568 additions and 258 deletions (+568, -258):
docs/backend/structured_outputs_for_reasoning_models.ipynb   (+14, -12)
python/sglang/srt/constrained/base_grammar_backend.py        (+49, -72)
python/sglang/srt/constrained/llguidance_backend.py          (+25, -21)
python/sglang/srt/constrained/outlines_backend.py            (+27, -26)
python/sglang/srt/constrained/reasoner_grammar_backend.py    (+22, -33)
python/sglang/srt/constrained/xgrammar_backend.py            (+69, -43)
python/sglang/srt/layers/sampler.py                          (+0, -4)
python/sglang/srt/managers/schedule_batch.py                 (+1, -0)
python/sglang/srt/managers/scheduler.py                      (+33, -8)
python/sglang/srt/managers/tokenizer_manager.py              (+7, -0)
python/sglang/srt/metrics/collector.py                       (+312, -37)
test/srt/test_json_constrained.py                            (+8, -2)
test/srt/test_metrics.py                                     (+1, -0)
docs/backend/structured_outputs_for_reasoning_models.ipynb (view file @ 01bdbf7f)

@@ -94,8 +94,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",

@@ -145,8 +145,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",

@@ -188,8 +188,8 @@
     "    messages=[\n",
     "        {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
     "        {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France.\",\n",
+    "            \"role\": \"assistant\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",

@@ -218,7 +218,7 @@
     "response = client.chat.completions.create(\n",
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
-    "        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
+    "        {\"role\": \"assistant\", \"content\": \"What is the capital of France?\"},\n",
     "    ],\n",
     "    temperature=0,\n",
     "    max_tokens=2048,\n",

@@ -323,7 +323,7 @@
     "You are a helpful assistant.\"\"\",\n",
     "    },\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
+    "        \"role\": \"assistant\",\n",
     "        \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
     "    },\n",
     "]\n",

@@ -400,9 +400,9 @@
     "\n",
     "messages = [\n",
     "    {\n",
-    "        \"role\": \"user\",\n",
-    "        \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
-    "    }\n",
+    "        \"role\": \"assistant\",\n",
+    "        \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
+    "    },\n",
     "]\n",
     "text = tokenizer.apply_chat_template(\n",
     "    messages, tokenize=False, add_generation_prompt=True\n",

@@ -452,7 +452,9 @@
     ")\n",
     "\n",
     "# JSON\n",
-    "text = tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)\n",
+    "text = tokenizer.apply_chat_template(\n",
+    "    messages, tokenize=False, add_generation_prompt=True\n",
+    ")\n",
     "response = requests.post(\n",
     "    f\"http://localhost:{port}/generate\",\n",
     "    json={\n",
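For reference, the last hunk above fixes a call that fed the already-rendered text string back into apply_chat_template. A minimal standalone sketch of the intended usage with Hugging Face transformers (the model name is taken from the notebook; the message content here is illustrative):

# Sketch: render a chat prompt from a message list, not from a pre-rendered string.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

messages = [
    {
        "role": "user",
        "content": "Give me the information and population of the capital of France in the JSON format.",
    },
]

# Pass the message list and ask for the generation prompt to be appended.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)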
python/sglang/srt/constrained/base_grammar_backend.py (view file @ 01bdbf7f)

@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""

 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple

 import torch

@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)


-class BaseGrammarObject(ABC):
+class BaseGrammarObject:

     def __init__(self):
         self._finished = False

+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished

@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished

-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.

@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
             A jump forward helper which may be used in `jump_forward_str_state`.
             None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.

@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
             A tuple of the jump forward string and the next state of the grammar
             (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()

-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
         """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()


 @dataclass

@@ -113,10 +103,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()

     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

     def dispatch_fallback(
         self, key_type: str, key_string: str

@@ -148,40 +137,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)

-    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-        val = self.cache[key].value
-        return val.copy() if val else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._init_value, key)
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False
+
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value

     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()


 def create_grammar_backend(

@@ -211,9 +185,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )

         grammar_backend = ReasonerGrammarBackend(
            grammar_backend, tokenizer.think_end_id
         )

     return grammar_backend
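The backend above replaces the lock/event cache with a simpler pattern: return a copy of a cached grammar on a hit, otherwise submit compilation to the thread pool and let the scheduler write the cache once the future resolves. A minimal sketch of that pattern under generic names (not the exact SGLang classes):

# Sketch of the cache-or-submit pattern introduced here (generic names and a toy compile function).
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, Tuple, Union


class GrammarCache:
    """Cache compiled grammars; compile cache misses in a worker thread."""

    def __init__(self, compile_fn):
        self.compile_fn = compile_fn  # stands in for the expensive schema -> matcher compilation
        self.executor = ThreadPoolExecutor()
        self.cache: Dict[Tuple[str, str], str] = {}

    def get_cached_or_future_value(self, key) -> Tuple[Union[str, Future], bool]:
        value = self.cache.get(key)
        if value is not None:
            # Hit: the real backend returns value.copy() so requests do not share matcher state.
            return value, True
        # Miss: schedule compilation; the caller polls the returned Future with a timeout.
        return self.executor.submit(self.compile_fn, key), False

    def set_cache(self, key, value) -> None:
        # The cache is written only after the Future resolves, from the scheduling thread,
        # so worker threads never mutate it.
        self.cache[key] = value


cache = GrammarCache(lambda key: f"compiled:{key[0]}:{key[1]}")
value, hit = cache.get_cached_or_future_value(("json", "{}"))
if not hit:
    cache.set_cache(("json", "{}"), value.result())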
python/sglang/srt/constrained/llguidance_backend.py (view file @ 01bdbf7f)

@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None

-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")

@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+

 class GuidanceBackend(BaseGrammarBackend):

@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None

     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
         return self._from_serialized(serialized_grammar)

     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
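The dispatch_json change above is the "server crash" part of this commit for the llguidance backend: a malformed user-supplied JSON schema is now logged and skipped instead of raising out of the dispatcher. A minimal sketch of the pattern with a generic compile function:

# Sketch of the defensive dispatch: compile_fn stands in for LLMatcher.grammar_from_json_schema.
import logging

logger = logging.getLogger(__name__)


def dispatch_json(compile_fn, key_string: str):
    try:
        return compile_fn(key_string)
    except Exception as e:
        # A bad schema is reported; the caller receives None instead of a grammar object.
        logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
        return None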
python/sglang/srt/constrained/outlines_backend.py (view file @ 01bdbf7f)

@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None

@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-

 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(

@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
             )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
             return None
         return self._compile_regex(regex)

     def dispatch_regex(self, key_string: str):
python/sglang/srt/constrained/reasoner_grammar_backend.py (view file @ 01bdbf7f)

@@ -13,7 +13,6 @@
 # ==============================================================================
 """The baseclass of a backend for reasoner grammar-guided constrained decoding."""

-from concurrent.futures import Future
 from typing import List, Optional, Tuple

 import torch

@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
         self.think_end_id = think_end_id
         self.is_in_reasoning = True

-    @property
-    def finished(self):
-        return self.grammar.finished
-
-    @finished.setter
-    def finished(self, finished):
-        self.grammar.finished = finished
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device

@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
     def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask

-    def accept_token(self, token: int):
-        if token == self.think_end_id:
-            self.is_in_reasoning = False
-
-        if not self.is_in_reasoning and token != self.think_end_id:
-            self.grammar.accept_token(token)
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished

     def try_jump_forward(self, tokenizer):
         return self.grammar.try_jump_forward(tokenizer)

@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
             old_output_ids, new_output_ids, next_state
         )

-    def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
-

 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        super().__init__()
         self.grammar_backend = grammar_backend
         self.think_end_id = think_end_id

-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
-        grammar = self.grammar_backend.get_cached_value(key)
-        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
-
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        grammar = Future()
-
-        def callback(f: Future):
-            if result := f.result():
-                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
-            else:
-                grammar.set_result(None)
-
-        self.grammar_backend.get_future_value(key).add_done_callback(callback)
-        return grammar
+    def _init_value_dispatch(
+        self, key: Tuple[str, str]
+    ) -> Optional[ReasonerGrammarObject]:
+        ret = self.grammar_backend._init_value_dispatch(key)
+        if ret is None:
+            return None
+        return ReasonerGrammarObject(ret, self.think_end_id)

     def reset(self):
         self.grammar_backend.reset()
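The wrapper above gates the underlying grammar on the reasoning phase: tokens are only forwarded once the "end of thinking" token has been seen, so the free-form reasoning text is not constrained. A minimal sketch of that gate with a generic wrapper (not the exact SGLang classes):

# Sketch of the reasoning gate implemented by ReasonerGrammarObject.
class ReasoningGate:
    def __init__(self, grammar, think_end_id: int):
        self.grammar = grammar            # any object with accept_token(int)
        self.think_end_id = think_end_id
        self.is_in_reasoning = True

    def accept_token(self, token: int) -> None:
        if token == self.think_end_id:
            self.is_in_reasoning = False      # leave the free-form reasoning phase
        if not self.is_in_reasoning and token != self.think_end_id:
            self.grammar.accept_token(token)  # constrain only the final answer tokens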
python/sglang/srt/constrained/xgrammar_backend.py (view file @ 01bdbf7f)

@@ -18,7 +18,6 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
-import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,

@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
 from sglang.srt.constrained.triton_ops.bitmask_ops import (
     apply_token_bitmask_inplace_triton,
 )
-from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)

@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
         vocab_size: int,
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
+        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
-
-        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
-            apply_token_bitmask_inplace_cpu,
-        )
-
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
+        self.accepted_tokens = []
+        self.key_string = key_string

     def accept_token(self, token: int):
-        assert self.matcher.accept_token(token)
-
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        s = self.matcher.find_jump_forward_string()
-        if s:
-            return [], s
-        return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        _, data = helper
-        return data, -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        k = 0
-        for i, old_id in enumerate(old_output_ids):
-            if old_id == new_output_ids[i]:
-                k = i + 1
-            else:
-                break
-
-        # rollback to the last token that is the same
-        if k < len(old_output_ids):
-            self.matcher.rollback(len(old_output_ids) - k)
-
-        for i in range(k, len(new_output_ids)):
-            assert self.matcher.accept_token(new_output_ids[i])
+        if not self.is_terminated():
+            accepted = self.matcher.accept_token(token)
+            if not accepted:
+                # log for debugging
+                raise ValueError(
+                    f"Tokens not accepted: {token}\n"
+                    f"Accepted tokens: {self.accepted_tokens}\n"
+                    f"Key string: {self.key_string}"
+                )
+            else:
+                self.accepted_tokens.append(token)
+
+    def rollback(self, k: int):
+        self.matcher.rollback(k)
+        self.accepted_tokens = self.accepted_tokens[:-k]
+
+    def is_terminated(self):
+        return self.matcher.is_terminated()

     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device

@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
             override_stop_tokens=self.override_stop_tokens,
         )
         return XGrammarGrammar(
-            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
+            matcher, self.vocab_size, self.ctx, self.override_stop_tokens, self.key_string
         )

+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, data = helper
+        return data, -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == new_output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # rollback to the last token that is the same
+        if k < len(old_output_ids):
+            self.matcher.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(new_output_ids)):
+            assert self.matcher.accept_token(new_output_ids[i])
+
+    def __repr__(self):
+        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
+

 class XGrammarGrammarBackend(BaseGrammarBackend):
     def __init__(

@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self.vocab_size = vocab_size
         self.override_stop_tokens = override_stop_tokens

-    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
-        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
-        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
+    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            override_stop_tokens=self.override_stop_tokens,
+        )
+        return XGrammarGrammar(
+            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+        )

     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:

@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:

@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:

@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         except RuntimeError as e:
             logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:

@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 tags, structural_tag["triggers"]
             )
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+            logging.warning(
+                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+            )
             return None
-        return self._from_context(ctx)
+        return self._from_context(ctx, key_string)

     def reset(self):
         if self.grammar_compiler:
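The accept_token rewrite above turns a bare assert on a rejected token into a descriptive ValueError carrying the accepted-token history and the grammar key, which makes the failure attributable to a specific request. A minimal sketch of the same idea around a generic matcher (not the xgrammar API):

# Sketch: track accepted tokens and raise a descriptive error instead of asserting.
class TrackedMatcher:
    def __init__(self, matcher, key_string: str = ""):
        self.matcher = matcher          # any object with accept_token(int) -> bool and is_terminated()
        self.key_string = key_string
        self.accepted_tokens = []

    def accept_token(self, token: int) -> None:
        if self.matcher.is_terminated():
            return
        if not self.matcher.accept_token(token):
            raise ValueError(
                f"Tokens not accepted: {token}\n"
                f"Accepted tokens: {self.accepted_tokens}\n"
                f"Key string: {self.key_string}"
            )
        self.accepted_tokens.append(token)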
python/sglang/srt/layers/sampler.py (view file @ 01bdbf7f)

@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
python/sglang/srt/managers/schedule_batch.py (view file @ 01bdbf7f)

@@ -533,6 +533,7 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0

         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
python/sglang/srt/managers/scheduler.py (view file @ 01bdbf7f)

@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass

@@ -1024,9 +1025,11 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:

@@ -1208,6 +1211,7 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)

@@ -1255,6 +1259,7 @@ class Scheduler(
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:

@@ -1715,11 +1720,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.05)
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:

@@ -1731,14 +1742,28 @@ class Scheduler(
         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
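The grammar-queue change above adds a timeout budget: each poll that times out bumps a per-request counter, and once the accumulated wait exceeds GRAMMAR_TIMEOUT the request is aborted instead of blocking the queue forever. A simplified, single-process sketch of that polling loop (callbacks stand in for the scheduler's own ready/abort handling):

# Sketch of polling grammar futures with a per-request wait budget.
import concurrent.futures as futures

GRAMMAR_TIMEOUT = 300.0  # seconds; the real code reads SGLANG_GRAMMAR_TIMEOUT
POLL_INTERVAL = 0.03


def poll_grammar_queue(grammar_queue, on_ready, on_abort):
    num_ready, num_abort = 0, 0
    for req in grammar_queue:
        try:
            req.grammar = req.grammar.result(timeout=POLL_INTERVAL)
            on_ready(req)
            num_ready += 1
        except futures.TimeoutError:
            req.grammar_wait_ct += 1
            if req.grammar_wait_ct > GRAMMAR_TIMEOUT / POLL_INTERVAL:
                on_abort(req)   # compilation took too long: flag the request for abort
                num_abort = 1
            break  # keep queue order: stop at the first request that is not ready yet
    return num_ready, num_abort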
python/sglang/srt/managers/tokenizer_manager.py (view file @ 01bdbf7f)

@@ -1230,11 +1230,18 @@ class TokenizerManager:
             state.last_completion_tokens = completion_tokens

         if state.finished:
+            has_grammar = (
+                state.obj.sampling_params.get("json_schema", None)
+                or state.obj.sampling_params.get("regex", None)
+                or state.obj.sampling_params.get("ebnf", None)
+                or state.obj.sampling_params.get("structural_tag", None)
+            )
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
+                has_grammar,
             )

     def dump_requests(self, state: ReqState, out_dict: dict):
python/sglang/srt/metrics/collector.py (view file @ 01bdbf7f)

@@ -15,7 +15,119 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Union
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.utils import get_bool_env_var
+
+SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
+
+
+@dataclass
+class TimeStats:
+    """
+    Store the timestamps for each stage of a request.
+
+    Unified: wait_queue -> forward -> completion
+    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
+    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
+    """
+
+    lb_entry_time: float = 0.0
+    wait_queue_entry_time: float = 0.0
+    forward_entry_time: float = 0.0
+    completion_time: float = 0.0
+    prefill_bootstrap_queue_entry_time: float = 0.0
+    prefill_transfer_queue_entry_time: float = 0.0
+    decode_prealloc_queue_entry_time: float = 0.0
+    decode_transfer_queue_entry_time: float = 0.0
+
+    class RequestType(Enum):
+        UNIFIED = "unified"
+        PREFILL = "prefill"
+        DECODE = "decode"
+        INVALID = "invalid"
+
+    def __str__(self) -> str:
+        # if unified
+        _type = self.get_type()
+
+        if _type == self.RequestType.UNIFIED:
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    queue_duration >= 0 and forward_duration >= 0
+                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
+        elif _type == self.RequestType.PREFILL:
+            bootstrap_duration = (
+                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    bootstrap_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
+        # if decode
+        elif _type == self.RequestType.DECODE:
+            prealloc_duration = (
+                self.decode_transfer_queue_entry_time
+                - self.decode_prealloc_queue_entry_time
+            )
+            transfer_duration = (
+                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+            )
+            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
+            forward_duration = self.completion_time - self.forward_entry_time
+
+            if SGLANG_TEST_REQUEST_TIME_STATS:
+                assert (
+                    prealloc_duration >= 0
+                    and transfer_duration >= 0
+                    and queue_duration >= 0
+                    and forward_duration >= 0
+                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
+
+            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
+        else:
+            return "Invalid Time Stats"
+
+    def format_duration(self, duration: float) -> str:
+        return f"{duration * 1e3:.2f}ms"
+
+    def get_type(self) -> RequestType:
+        """Determine the type of request based on timestamp values."""
+        if (
+            self.prefill_bootstrap_queue_entry_time == 0.0
+            and self.prefill_transfer_queue_entry_time == 0.0
+            and self.decode_prealloc_queue_entry_time == 0.0
+            and self.decode_transfer_queue_entry_time == 0.0
+        ):
+            return self.RequestType.UNIFIED
+        elif (
+            self.prefill_bootstrap_queue_entry_time > 0.0
+            and self.prefill_transfer_queue_entry_time > 0.0
+        ):
+            return self.RequestType.PREFILL
+        elif (
+            self.decode_prealloc_queue_entry_time > 0.0
+            and self.decode_transfer_queue_entry_time > 0.0
+            and self.wait_queue_entry_time > 0.0
+        ):
+            return self.RequestType.DECODE
+        else:
+            return self.RequestType.INVALID


 @dataclass

@@ -26,15 +138,20 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    num_grammar_queue_reqs: int = 0
     spec_accept_length: float = 0.0
     avg_request_queue_latency: float = 0.0
+    num_prefill_prealloc_queue_reqs: int = 0
+    num_prefill_infight_queue_reqs: int = 0
+    num_decode_prealloc_queue_reqs: int = 0
+    num_decode_transfer_queue_reqs: int = 0


 class SchedulerMetricsCollector:

     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge, Histogram
+        from prometheus_client import Counter, Gauge

         self.labels = labels
         self.last_log_time = time.time()

@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.num_grammar_queue_reqs = Gauge(
+            name="sglang:num_grammar_queue_reqs",
+            documentation="The number of requests in the grammar waiting queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
             documentation="The prefix cache hit rate.",

@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        # Disaggregation queue metrics
+        self.num_prefill_prealloc_queue_reqs = Gauge(
+            name="sglang:num_prefill_prealloc_queue_reqs",
+            documentation="The number of requests in the prefill prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_prefill_infight_queue_reqs = Gauge(
+            name="sglang:num_prefill_infight_queue_reqs",
+            documentation="The number of requests in the prefill infight queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_decode_prealloc_queue_reqs = Gauge(
+            name="sglang:num_decode_prealloc_queue_reqs",
+            documentation="The number of requests in the decode prealloc queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_decode_transfer_queue_reqs = Gauge(
+            name="sglang:num_decode_transfer_queue_reqs",
+            documentation="The number of requests in the decode transfer queue.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.num_bootstrap_failed_reqs = Counter(
+            name="sglang:num_bootstrap_failed_reqs",
+            documentation="The number of bootstrap failed requests.",
+            labelnames=labels.keys(),
+        )
+
+        self.num_transfer_failed_reqs = Counter(
+            name="sglang:num_transfer_failed_reqs",
+            documentation="The number of transfer failed requests.",
+            labelnames=labels.keys(),
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)

+    def increment_bootstrap_failed_reqs(self) -> None:
+        self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
+
+    def increment_transfer_failed_reqs(self) -> None:
+        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
+
     def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
         self._log_gauge(self.token_usage, stats.token_usage)
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+        self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
+
+        # Disaggregation metrics
+        self._log_gauge(self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs)
+        self._log_gauge(self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs)
+        self._log_gauge(self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs)
+        self._log_gauge(self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs)
+
         self.last_log_time = time.time()


 class TokenizerMetricsCollector:
-    def __init__(self, labels: Dict[str, str]) -> None:
+    def __init__(
+        self,
+        labels: Dict[str, str],
+        bucket_time_to_first_token: Optional[List[float]] = None,
+        bucket_inter_token_latency: Optional[List[float]] = None,
+        bucket_e2e_request_latency: Optional[List[float]] = None,
+        collect_tokens_histogram: bool = False,
+    ) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
         from prometheus_client import Counter, Histogram

         self.labels = labels
+        self.collect_tokens_histogram = collect_tokens_histogram

         self.prompt_tokens_total = Counter(
             name="sglang:prompt_tokens_total",

@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

+        if collect_tokens_histogram:
+            bucket_prompt_tokens = [
+                100, 300, 500, 700, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, 8000,
+                9000, 10000, 12000, 15000, 20000, 22000, 25000, 30000, 35000, 40000,
+            ]
+            self.prompt_tokens_histogram = Histogram(
+                name="sglang:prompt_tokens_histogram",
+                documentation="Histogram of prompt token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_prompt_tokens,
+            )
+            bucket_generation_tokens = [
+                100, 300, 500, 1000, 1200, 1500, 1700, 2000, 2500, 3000, 3500, 4000,
+                4500, 5000, 6000, 7000, 8000, 9000, 10000,
+            ]
+            self.generation_tokens_histogram = Histogram(
+                name="sglang:generation_tokens_histogram",
+                documentation="Histogram of generation token length.",
+                labelnames=labels.keys(),
+                buckets=bucket_generation_tokens,
+            )
+
         self.cached_tokens_total = Counter(
             name="sglang:cached_tokens_total",
             documentation="Number of cached prompt tokens.",

@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

-        self.histogram_time_to_first_token = Histogram(
-            name="sglang:time_to_first_token_seconds",
-            documentation="Histogram of time to first token in seconds.",
+        self.num_so_requests_total = Counter(
+            name="sglang:num_so_requests_total",
+            documentation="Number of structured output requests processed.",
             labelnames=labels.keys(),
-            buckets=[
+        )
+
+        if bucket_time_to_first_token is None:
+            bucket_time_to_first_token = [
                 0.1,
                 0.2,
                 0.4,

@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
                 100,
                 200,
                 400,
-            ],
-        )
+            ]

-        self.histogram_inter_token_latency_seconds = Histogram(
-            name="sglang:inter_token_latency_seconds",
-            documentation="Histogram of inter-token latency in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
+        if bucket_e2e_request_latency is None:
+            bucket_e2e_request_latency = [
+                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10,
+                20, 40, 60, 80, 100, 200, 400, 800,
+            ]
+
+        if bucket_inter_token_latency is None:
+            bucket_inter_token_latency = [
                 0.002,
                 0.004,
                 0.006,

@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
                 4.000,
                 6.000,
                 8.000,
-            ],
+            ]
+
+        self.histogram_time_to_first_token = Histogram(
+            name="sglang:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_time_to_first_token,
+        )
+
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=bucket_inter_token_latency,
         )

         self.histogram_e2e_request_latency = Histogram(
             name="sglang:e2e_request_latency_seconds",
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
-            buckets=[
-                0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10,
-                20, 40, 60, 80, 100, 200, 400, 800,
-            ],
+            buckets=bucket_e2e_request_latency,
         )

     def _log_histogram(self, histogram, data: Union[int, float]) -> None:

@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
         generation_tokens: int,
         cached_tokens: int,
         e2e_latency: float,
+        has_grammar: bool,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         if cached_tokens > 0:
             self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        if has_grammar:
+            self.num_so_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if self.collect_tokens_histogram:
+            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
+            self._log_histogram(self.generation_tokens_histogram, generation_tokens)

     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
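The new metrics follow the standard prometheus_client usage pattern: each metric is created once with its label names, then updated through .labels(...). A minimal, hedged sketch of that pattern (metric names below are examples only; the real collector also sets multiprocess_mode on gauges, which is omitted here):

# Sketch of prometheus_client Counter/Gauge usage as applied by the new collector code.
from prometheus_client import Counter, Gauge

labels = {"model_name": "example-model"}

num_grammar_queue_reqs = Gauge(
    "example_num_grammar_queue_reqs",
    "The number of requests in the grammar waiting queue.",
    labelnames=labels.keys(),
)
num_so_requests_total = Counter(
    "example_num_so_requests_total",
    "Number of structured output requests processed.",
    labelnames=labels.keys(),
)

num_grammar_queue_reqs.labels(**labels).set(3)  # gauge: set to the current queue length
num_so_requests_total.labels(**labels).inc(1)   # counter: one finished structured-output request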
test/srt/test_json_constrained.py (view file @ 01bdbf7f)

@@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
         print(json.dumps(ret))
         print("=" * 100)

-        if not json_schema:
+        if not json_schema or json_schema == "INVALID":
             return

         # Make sure the json output is valid

@@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)

+    def test_json_invalid(self):
+        self.run_decode(json_schema="INVALID")
+
     def test_json_openai(self):
         client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")

@@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
             model=self.model,
             messages=[
                 {"role": "system", "content": "You are a helpful AI assistant"},
-                {"role": "user", "content": "Introduce the capital of France."},
+                {
+                    "role": "user",
+                    "content": "Introduce the capital of France. Return in a JSON format.",
+                },
             ],
             temperature=0,
             max_tokens=128,
test/srt/test_metrics.py (view file @ 01bdbf7f)

@@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase):
                 "sglang:token_usage",
                 "sglang:gen_throughput",
                 "sglang:num_queue_reqs",
+                "sglang:num_grammar_queue_reqs",
                 "sglang:cache_hit_rate",
                 "sglang:spec_accept_length",
                 "sglang:prompt_tokens_total",