Unverified Commit 01bdbf7f authored by Lianmin Zheng, committed by GitHub

Improve structured outputs: fix race condition, server crash, metrics and style (#6188)

parent 94d42b67
...@@ -94,8 +94,8 @@ ...@@ -94,8 +94,8 @@
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n", " model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n", " messages=[\n",
" {\n", " {\n",
" \"role\": \"user\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n", " \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n", " },\n",
" ],\n", " ],\n",
" temperature=0,\n", " temperature=0,\n",
...@@ -145,8 +145,8 @@ ...@@ -145,8 +145,8 @@
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n", " model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n", " messages=[\n",
" {\n", " {\n",
" \"role\": \"user\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", " \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n", " },\n",
" ],\n", " ],\n",
" temperature=0,\n", " temperature=0,\n",
...@@ -188,8 +188,8 @@ ...@@ -188,8 +188,8 @@
" messages=[\n", " messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n", " {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
" {\n", " {\n",
" \"role\": \"user\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Give me the information of the capital of France.\",\n", " \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n", " },\n",
" ],\n", " ],\n",
" temperature=0,\n", " temperature=0,\n",
...@@ -218,7 +218,7 @@ ...@@ -218,7 +218,7 @@
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n", " model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", " {\"role\": \"assistant\", \"content\": \"What is the capital of France?\"},\n",
" ],\n", " ],\n",
" temperature=0,\n", " temperature=0,\n",
" max_tokens=2048,\n", " max_tokens=2048,\n",
...@@ -323,7 +323,7 @@ ...@@ -323,7 +323,7 @@
"You are a helpful assistant.\"\"\",\n", "You are a helpful assistant.\"\"\",\n",
" },\n", " },\n",
" {\n", " {\n",
" \"role\": \"user\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n", " \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
" },\n", " },\n",
" ]\n", " ]\n",
...@@ -400,9 +400,9 @@ ...@@ -400,9 +400,9 @@
"\n", "\n",
"messages = [\n", "messages = [\n",
" {\n", " {\n",
" \"role\": \"user\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" }\n", " },\n",
"]\n", "]\n",
"text = tokenizer.apply_chat_template(\n", "text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n", " messages, tokenize=False, add_generation_prompt=True\n",
...@@ -452,7 +452,9 @@ ...@@ -452,7 +452,9 @@
")\n", ")\n",
"\n", "\n",
"# JSON\n", "# JSON\n",
"text = tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)\n", "text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"response = requests.post(\n", "response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
......
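The notebook fix above passes the `messages` list (rather than an already-rendered string) to `apply_chat_template` before calling the native `/generate` endpoint. Below is a minimal sketch of that corrected flow; the port, the example schema, the message role, and the exact `/generate` payload shape are illustrative assumptions, not taken from this hunk.

import requests
from transformers import AutoTokenizer

port = 30000  # assumed local SGLang server port
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

messages = [
    {
        "role": "user",  # illustrative; the notebook cells above use their own roles
        "content": "Give me the information and population of the capital of France in the JSON format.",
    },
]
# The fix: apply_chat_template receives the messages list, not a pre-rendered string.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

json_schema = (
    '{"type": "object", "properties": {"name": {"type": "string"}, '
    '"population": {"type": "integer"}}, "required": ["name", "population"]}'
)  # example schema, not from the notebook
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 2048,
            "json_schema": json_schema,
        },
    },
)
print(response.json())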
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
"""The baseclass of a backend for grammar-guided constrained decoding.""" """The baseclass of a backend for grammar-guided constrained decoding."""
import logging import logging
from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from threading import Event, Lock from threading import Event
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
...@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs ...@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseGrammarObject(ABC): class BaseGrammarObject:
def __init__(self): def __init__(self):
self._finished = False self._finished = False
def accept_token(self, token: int) -> None:
"""
Accept a token in the grammar.
"""
raise NotImplementedError()
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
raise NotImplementedError()
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
raise NotImplementedError()
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
raise NotImplementedError()
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
raise NotImplementedError()
def copy(self) -> "BaseGrammarObject":
raise NotImplementedError()
@property @property
def finished(self): def finished(self):
return self._finished return self._finished
...@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC): ...@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
def finished(self, finished): def finished(self, finished):
self._finished = finished self._finished = finished
@abstractmethod
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
""" """
Try to jump forward in the grammar. Try to jump forward in the grammar.
...@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC): ...@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
A jump forward helper which may be used in `jump_forward_str_state`. A jump forward helper which may be used in `jump_forward_str_state`.
None if the jump forward is not possible. None if the jump forward is not possible.
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
""" """
Jump forward for the grammar. Jump forward for the grammar.
...@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC): ...@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
A tuple of the jump forward string and the next state of the grammar A tuple of the jump forward string and the next state of the grammar
(which can be used in `jump_and_retokenize` if needed). (which can be used in `jump_and_retokenize` if needed).
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod
def jump_and_retokenize( def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
) -> None: ) -> None:
""" """
Jump forward occurs, and update the grammar state if needed. Jump forward occurs, and update the grammar state if needed.
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod
def accept_token(self, token: int) -> None:
"""
Accept a token in the grammar.
"""
raise NotImplementedError
@abstractmethod
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
raise NotImplementedError
@staticmethod
@abstractmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
raise NotImplementedError
@abstractmethod
def copy(self) -> "BaseGrammarObject":
raise NotImplementedError
@dataclass @dataclass
...@@ -113,10 +103,9 @@ class BaseGrammarBackend: ...@@ -113,10 +103,9 @@ class BaseGrammarBackend:
def __init__(self): def __init__(self):
self.executor = ThreadPoolExecutor() self.executor = ThreadPoolExecutor()
self.cache: Dict[Tuple[str, str], CacheEntry] = {} self.cache: Dict[Tuple[str, str], CacheEntry] = {}
self.cache_lock = Lock()
def _not_supported(self, key_type: str, key_string: str) -> None: def _not_supported(self, key_type: str, key_string: str) -> None:
logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}") logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
def dispatch_fallback( def dispatch_fallback(
self, key_type: str, key_string: str self, key_type: str, key_string: str
...@@ -148,40 +137,25 @@ class BaseGrammarBackend: ...@@ -148,40 +137,25 @@ class BaseGrammarBackend:
return self.dispatch_ebnf(key_string) return self.dispatch_ebnf(key_string)
elif key_type == "structural_tag": elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string) return self.dispatch_structural_tag(key_string)
elif key_type == "structural_pattern":
return self.dispatch_structural_pattern(key_string)
else: else:
return self.dispatch_fallback(key_type, key_string) return self.dispatch_fallback(key_type, key_string)
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: def get_cached_or_future_value(
with self.cache_lock: self, key: Tuple[str, str]
if key in self.cache: ) -> Optional[BaseGrammarObject]:
cache_hit = True value = self.cache.get(key)
entry = self.cache[key] if value:
else: return value.copy(), True
cache_hit = False value = self.executor.submit(self._init_value_dispatch, key)
entry = CacheEntry(None, Event()) return value, False
self.cache[key] = entry
if cache_hit:
entry.event.wait()
else:
entry.value = self._init_value_dispatch(key)
entry.event.set()
return entry.value.copy() if entry.value else None
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
entry = self.cache.get(key)
if not entry or not entry.event.is_set():
return None
val = self.cache[key].value
return val.copy() if val else None
def get_future_value(self, key: Tuple[str, str]) -> Future: def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
return self.executor.submit(self._init_value, key) self.cache[key] = value
def reset(self): def reset(self):
with self.cache_lock: self.cache.clear()
self.cache.clear()
def create_grammar_backend( def create_grammar_backend(
...@@ -211,9 +185,12 @@ def create_grammar_backend( ...@@ -211,9 +185,12 @@ def create_grammar_backend(
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"): if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
from .reasoner_grammar_backend import ReasonerGrammarBackend from sglang.srt.constrained.reasoner_grammar_backend import (
ReasonerGrammarBackend,
)
grammar_backend = ReasonerGrammarBackend( grammar_backend = ReasonerGrammarBackend(
grammar_backend, tokenizer.think_end_id grammar_backend, tokenizer.think_end_id
) )
return grammar_backend return grammar_backend
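The race-condition fix in base_grammar_backend.py replaces the lock- and event-guarded `_init_value` / `get_cached_value` / `get_future_value` trio with a single `get_cached_or_future_value()` that returns either a copy of a cached grammar or a Future, plus an explicit `set_cache()` that the scheduler calls once the future resolves, so the cache dict is only mutated from the scheduler loop. A self-contained sketch of that protocol, using dummy stand-ins rather than a real backend:

from concurrent.futures import ThreadPoolExecutor

class DummyGrammar:
    def copy(self):
        return DummyGrammar()

class DummyBackend:
    # Mirrors BaseGrammarBackend.get_cached_or_future_value / set_cache from the diff.
    def __init__(self):
        self.executor = ThreadPoolExecutor()
        self.cache = {}

    def _init_value_dispatch(self, key):
        return DummyGrammar()  # real backends compile the grammar here

    def get_cached_or_future_value(self, key):
        value = self.cache.get(key)
        if value:
            return value.copy(), True  # cache hit: hand out a copy
        return self.executor.submit(self._init_value_dispatch, key), False

    def set_cache(self, key, value):
        self.cache[key] = value

backend = DummyBackend()
key = ("json", '{"type": "object"}')

value, cache_hit = backend.get_cached_or_future_value(key)
if not cache_hit:
    grammar = value.result(timeout=1.0)     # value is a Future on a miss
    backend.set_cache(key, grammar.copy())  # published once, from one thread
else:
    grammar = value                         # already a copy of the cached grammar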
...@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
self.finished = False self.finished = False
self.bitmask = None self.bitmask = None
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
ff_tokens = self.ll_matcher.compute_ff_tokens()
if ff_tokens:
return ff_tokens, ""
else:
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
pass
def accept_token(self, token: int): def accept_token(self, token: int):
if not self.ll_matcher.consume_token(token): if not self.ll_matcher.consume_token(token):
logger.warning(f"matcher error: {self.ll_matcher.get_error()}") logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
...@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject): ...@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
serialized_grammar=self.serialized_grammar, serialized_grammar=self.serialized_grammar,
) )
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
ff_tokens = self.ll_matcher.compute_ff_tokens()
if ff_tokens:
return ff_tokens, ""
else:
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
pass
class GuidanceBackend(BaseGrammarBackend): class GuidanceBackend(BaseGrammarBackend):
...@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend): ...@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
return None return None
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]: def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
serialized_grammar = LLMatcher.grammar_from_json_schema( try:
key_string, serialized_grammar = LLMatcher.grammar_from_json_schema(
defaults={ key_string,
"whitespace_pattern": self.whitespace_pattern, defaults={
}, "whitespace_pattern": self.whitespace_pattern,
) },
)
except Exception as e:
logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
return None
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]: def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
......
...@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject): ...@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
def accept_token(self, token: int): def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token) self.state = self.guide.get_next_state(self.state, token)
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
).to(vocab_mask.device, non_blocking=True)
vocab_mask = vocab_mask[idx]
vocab_mask.fill_(1)
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
logits.masked_fill_(vocab_mask, float("-inf"))
def copy(self):
return OutlinesGrammar(self.guide, self.jump_forward_map)
def try_jump_forward(self, tokenizer) -> Optional[Tuple]: def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
if not self.jump_forward_map: if not self.jump_forward_map:
return None return None
...@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject): ...@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
): ):
self.state = next_state self.state = next_state
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
).to(vocab_mask.device, non_blocking=True)
vocab_mask = vocab_mask[idx]
vocab_mask.fill_(1)
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
logits.masked_fill_(vocab_mask, float("-inf"))
def copy(self):
return OutlinesGrammar(self.guide, self.jump_forward_map)
class OutlinesGrammarBackend(BaseGrammarBackend): class OutlinesGrammarBackend(BaseGrammarBackend):
def __init__( def __init__(
...@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ...@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
key_string, key_string,
whitespace_pattern=self.whitespace_pattern, whitespace_pattern=self.whitespace_pattern,
) )
except (NotImplementedError, json.decoder.JSONDecodeError) as e: except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}") logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
return None
return self._compile_regex(regex) return self._compile_regex(regex)
def dispatch_regex(self, key_string: str): def dispatch_regex(self, key_string: str):
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# ============================================================================== # ==============================================================================
"""The baseclass of a backend for reasoner grammar-guided constrained decoding.""" """The baseclass of a backend for reasoner grammar-guided constrained decoding."""
from concurrent.futures import Future
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject): ...@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
self.think_end_id = think_end_id self.think_end_id = think_end_id
self.is_in_reasoning = True self.is_in_reasoning = True
@property def accept_token(self, token: int):
def finished(self): if token == self.think_end_id:
return self.grammar.finished self.is_in_reasoning = False
@finished.setter if not self.is_in_reasoning and token != self.think_end_id:
def finished(self, finished): self.grammar.accept_token(token)
self.grammar.finished = finished
def allocate_vocab_mask( def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device self, vocab_size: int, batch_size: int, device
...@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject): ...@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
def apply_vocab_mask(self): def apply_vocab_mask(self):
return self.grammar.apply_vocab_mask return self.grammar.apply_vocab_mask
def accept_token(self, token: int): def copy(self) -> BaseGrammarObject:
if token == self.think_end_id: return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
self.is_in_reasoning = False
if not self.is_in_reasoning and token != self.think_end_id: @property
self.grammar.accept_token(token) def finished(self):
return self.grammar.finished
@finished.setter
def finished(self, finished):
self.grammar.finished = finished
def try_jump_forward(self, tokenizer): def try_jump_forward(self, tokenizer):
return self.grammar.try_jump_forward(tokenizer) return self.grammar.try_jump_forward(tokenizer)
...@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject): ...@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
old_output_ids, new_output_ids, next_state old_output_ids, new_output_ids, next_state
) )
def copy(self) -> BaseGrammarObject:
return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
class ReasonerGrammarBackend(BaseGrammarBackend): class ReasonerGrammarBackend(BaseGrammarBackend):
def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id): def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
super().__init__()
self.grammar_backend = grammar_backend self.grammar_backend = grammar_backend
self.think_end_id = think_end_id self.think_end_id = think_end_id
def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]: def _init_value_dispatch(
grammar = self.grammar_backend.get_cached_value(key) self, key: Tuple[str, str]
return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None ) -> Optional[ReasonerGrammarObject]:
ret = self.grammar_backend._init_value_dispatch(key)
def get_future_value(self, key: Tuple[str, str]) -> Future: if ret is None:
grammar = Future() return None
return ReasonerGrammarObject(ret, self.think_end_id)
def callback(f: Future):
if result := f.result():
grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
else:
grammar.set_result(None)
self.grammar_backend.get_future_value(key).add_done_callback(callback)
return grammar
def reset(self):
self.grammar_backend.reset()
...@@ -18,7 +18,6 @@ import logging ...@@ -18,7 +18,6 @@ import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import xgrammar
from xgrammar import ( from xgrammar import (
CompiledGrammar, CompiledGrammar,
GrammarCompiler, GrammarCompiler,
...@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import ( ...@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
from sglang.srt.constrained.triton_ops.bitmask_ops import ( from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton, apply_token_bitmask_inplace_triton,
) )
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
vocab_size: int, vocab_size: int,
ctx: CompiledGrammar, ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]], override_stop_tokens: Optional[Union[List[int], int]],
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
) -> None: ) -> None:
super().__init__()
self.matcher = matcher self.matcher = matcher
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.ctx = ctx self.ctx = ctx
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
self.finished = False self.finished = False
self.accepted_tokens = []
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import ( self.key_string = key_string
apply_token_bitmask_inplace_cpu,
)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
def accept_token(self, token: int): def accept_token(self, token: int):
assert self.matcher.accept_token(token) if not self.is_terminated():
accepted = self.matcher.accept_token(token)
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: if not accepted:
s = self.matcher.find_jump_forward_string() # log for debugging
if s: raise ValueError(
return [], s f"Tokens not accepted: {token}\n"
return None f"Accepted tokens: {self.accepted_tokens}\n"
f"Key string: {self.key_string}"
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: )
_, data = helper
return data, -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else: else:
break self.accepted_tokens.append(token)
# rollback to the last token that is the same def rollback(self, k: int):
if k < len(old_output_ids): self.matcher.rollback(k)
self.matcher.rollback(len(old_output_ids) - k) self.accepted_tokens = self.accepted_tokens[:-k]
for i in range(k, len(new_output_ids)): def is_terminated(self):
assert self.matcher.accept_token(new_output_ids[i]) return self.matcher.is_terminated()
def allocate_vocab_mask( def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device self, vocab_size: int, batch_size: int, device
...@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
override_stop_tokens=self.override_stop_tokens, override_stop_tokens=self.override_stop_tokens,
) )
return XGrammarGrammar( return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens matcher,
self.vocab_size,
self.ctx,
self.override_stop_tokens,
self.key_string,
) )
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
s = self.matcher.find_jump_forward_string()
if s:
return [], s
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, data = helper
return data, -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else:
break
# rollback to the last token that is the same
if k < len(old_output_ids):
self.matcher.rollback(len(old_output_ids) - k)
for i in range(k, len(new_output_ids)):
assert self.matcher.accept_token(new_output_ids[i])
def __repr__(self):
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
class XGrammarGrammarBackend(BaseGrammarBackend): class XGrammarGrammarBackend(BaseGrammarBackend):
def __init__( def __init__(
...@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar: def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) matcher = GrammarMatcher(
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens) ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
)
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}") logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
return None return None
return self._from_context(ctx) return self._from_context(ctx, key_string)
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None return None
return self._from_context(ctx) return self._from_context(ctx, key_string)
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None return None
return self._from_context(ctx) return self._from_context(ctx, key_string)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
...@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tags, structural_tag["triggers"] tags, structural_tag["triggers"]
) )
except RuntimeError as e: except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") logging.warning(
f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
)
return None return None
return self._from_context(ctx) return self._from_context(ctx, key_string)
def reset(self): def reset(self):
if self.grammar_compiler: if self.grammar_compiler:
......
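In xgrammar_backend.py, `accept_token` no longer relies on a bare assert (which could bring down the whole server on a rejected token); it now skips terminated matchers and raises a ValueError carrying the accepted-token history and the grammar key string. A hedged sketch of how a caller might contain that error; `grammar` and `next_token_id` are illustrative placeholders, not names from the diff:

import logging

logger = logging.getLogger(__name__)

def safe_accept(grammar, next_token_id: int) -> bool:
    """Return True if the grammar accepted the token, False otherwise."""
    try:
        grammar.accept_token(next_token_id)
        return True
    except ValueError as e:
        # New behaviour: a rejected token raises with debugging context
        # instead of tripping an assert and crashing the scheduler process.
        logger.warning("Constrained decoding rejected token %d: %s", next_token_id, e)
        return False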
...@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch( ...@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
assert len(top_logprobs_nums) == logprobs.shape[0], (
len(top_logprobs_nums),
logprobs.shape[0],
)
max_k = max(top_logprobs_nums) max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1) ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist() values = ret.values.tolist()
......
...@@ -533,6 +533,7 @@ class Req: ...@@ -533,6 +533,7 @@ class Req:
# Constrained decoding # Constrained decoding
self.grammar: Optional[BaseGrammarObject] = None self.grammar: Optional[BaseGrammarObject] = None
self.grammar_wait_ct = 0
# The number of cached tokens that were already cached in the KV cache # The number of cached tokens that were already cached in the KV cache
self.cached_tokens = 0 self.cached_tokens = 0
......
...@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__) ...@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes # Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass @dataclass
...@@ -1024,9 +1025,11 @@ class Scheduler( ...@@ -1024,9 +1025,11 @@ class Scheduler(
elif req.sampling_params.structural_tag: elif req.sampling_params.structural_tag:
key = ("structural_tag", req.sampling_params.structural_tag) key = ("structural_tag", req.sampling_params.structural_tag)
req.grammar = self.grammar_backend.get_cached_value(key) value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
if not req.grammar: req.grammar = value
req.grammar = self.grammar_backend.get_future_value(key)
if not cache_hit:
req.grammar_key = key
add_to_grammar_queue = True add_to_grammar_queue = True
if add_to_grammar_queue: if add_to_grammar_queue:
...@@ -1208,6 +1211,7 @@ class Scheduler( ...@@ -1208,6 +1211,7 @@ class Scheduler(
self.stats.cache_hit_rate = 0.0 self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.stats.spec_accept_length = spec_accept_length self.stats.spec_accept_length = spec_accept_length
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
...@@ -1255,6 +1259,7 @@ class Scheduler( ...@@ -1255,6 +1259,7 @@ class Scheduler(
self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = 0 self.stats.gen_throughput = 0
self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
...@@ -1715,11 +1720,17 @@ class Scheduler( ...@@ -1715,11 +1720,17 @@ class Scheduler(
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs = 0 num_ready_reqs = 0
num_abort_reqs = 0
for req in self.grammar_queue: for req in self.grammar_queue:
try: try:
req.grammar = req.grammar.result(timeout=0.05) req.grammar = req.grammar.result(timeout=0.03)
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
num_ready_reqs += 1 num_ready_reqs += 1
except futures._base.TimeoutError: except futures._base.TimeoutError:
req.grammar_wait_ct += 1
if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
num_abort_reqs = 1
break break
if self.server_args.enable_dp_attention: if self.server_args.enable_dp_attention:
...@@ -1731,14 +1742,28 @@ class Scheduler( ...@@ -1731,14 +1742,28 @@ class Scheduler(
if tp_size > 1: if tp_size > 1:
# Sync across TP ranks to make sure they have the same number of ready requests # Sync across TP ranks to make sure they have the same number of ready requests
tensor = torch.tensor(num_ready_reqs, dtype=torch.int32) tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
torch.distributed.all_reduce( torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
) )
num_ready_reqs_max = tensor.item() num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
for i in range(num_ready_reqs, num_ready_reqs_max): for i in range(num_ready_reqs, num_ready_reqs_max):
self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result() req = self.grammar_queue[i]
num_ready_reqs = num_ready_reqs_max req.grammar = req.grammar.result()
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
req.grammar = None
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
logger.error(error_msg)
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:] self.grammar_queue = self.grammar_queue[num_ready_reqs:]
......
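The scheduler change above polls each pending grammar future for 0.03 s per pass and counts the waits in `req.grammar_wait_ct`; once the count exceeds `GRAMMAR_TIMEOUT / 0.03` the request is aborted with a BadRequestError instead of stalling the grammar queue. Since `GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))` is read at import time, the limit has to be set in the environment before the scheduler starts, e.g.:

import os

# Raise the grammar-compilation timeout to 10 minutes (default is 300 seconds).
os.environ["SGLANG_GRAMMAR_TIMEOUT"] = "600"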
...@@ -1230,11 +1230,18 @@ class TokenizerManager: ...@@ -1230,11 +1230,18 @@ class TokenizerManager:
state.last_completion_tokens = completion_tokens state.last_completion_tokens = completion_tokens
if state.finished: if state.finished:
has_grammar = (
state.obj.sampling_params.get("json_schema", None)
or state.obj.sampling_params.get("regex", None)
or state.obj.sampling_params.get("ebnf", None)
or state.obj.sampling_params.get("structural_tag", None)
)
self.metrics_collector.observe_one_finished_request( self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i], recv_obj.prompt_tokens[i],
completion_tokens, completion_tokens,
recv_obj.cached_tokens[i], recv_obj.cached_tokens[i],
state.finished_time - state.created_time, state.finished_time - state.created_time,
has_grammar,
) )
def dump_requests(self, state: ReqState, out_dict: dict): def dump_requests(self, state: ReqState, out_dict: dict):
......
...@@ -15,7 +15,119 @@ ...@@ -15,7 +15,119 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Union from enum import Enum
from typing import Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var
SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")
@dataclass
class TimeStats:
"""
Store the timestamps for each stage of a request.
Unified: wait_queue -> forward -> completion
Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
"""
lb_entry_time: float = 0.0
wait_queue_entry_time: float = 0.0
forward_entry_time: float = 0.0
completion_time: float = 0.0
prefill_bootstrap_queue_entry_time: float = 0.0
prefill_transfer_queue_entry_time: float = 0.0
decode_prealloc_queue_entry_time: float = 0.0
decode_transfer_queue_entry_time: float = 0.0
class RequestType(Enum):
UNIFIED = "unified"
PREFILL = "prefill"
DECODE = "decode"
INVALID = "invalid"
def __str__(self) -> str:
# if unified
_type = self.get_type()
if _type == self.RequestType.UNIFIED:
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
queue_duration >= 0 and forward_duration >= 0
), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
elif _type == self.RequestType.PREFILL:
bootstrap_duration = (
self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
)
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
bootstrap_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
# if decode
elif _type == self.RequestType.DECODE:
prealloc_duration = (
self.decode_transfer_queue_entry_time
- self.decode_prealloc_queue_entry_time
)
transfer_duration = (
self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
)
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
forward_duration = self.completion_time - self.forward_entry_time
if SGLANG_TEST_REQUEST_TIME_STATS:
assert (
prealloc_duration >= 0
and transfer_duration >= 0
and queue_duration >= 0
and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
else:
return "Invalid Time Stats"
def format_duration(self, duration: float) -> str:
return f"{duration * 1e3:.2f}ms"
def get_type(self) -> RequestType:
"""Determine the type of request based on timestamp values."""
if (
self.prefill_bootstrap_queue_entry_time == 0.0
and self.prefill_transfer_queue_entry_time == 0.0
and self.decode_prealloc_queue_entry_time == 0.0
and self.decode_transfer_queue_entry_time == 0.0
):
return self.RequestType.UNIFIED
elif (
self.prefill_bootstrap_queue_entry_time > 0.0
and self.prefill_transfer_queue_entry_time > 0.0
):
return self.RequestType.PREFILL
elif (
self.decode_prealloc_queue_entry_time > 0.0
and self.decode_transfer_queue_entry_time > 0.0
and self.wait_queue_entry_time > 0.0
):
return self.RequestType.DECODE
else:
return self.RequestType.INVALID
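The new `TimeStats` dataclass records one timestamp per stage, classifies a request as unified, prefill, or decode based on which timestamps are non-zero, and renders the stage durations in `__str__` (with non-negativity asserts when `SGLANG_TEST_REQUEST_TIME_STATS` is set). A small usage sketch; the import path is an assumption:

import time

from sglang.srt.metrics.collector import TimeStats  # assumed module path

stats = TimeStats()
stats.wait_queue_entry_time = time.time()
stats.forward_entry_time = stats.wait_queue_entry_time + 0.12
stats.completion_time = stats.forward_entry_time + 1.90

# Only the unified timestamps are set, so the request is classified as UNIFIED.
assert stats.get_type() == TimeStats.RequestType.UNIFIED
print(stats)  # queue_duration=120.00ms, forward_duration=1900.00ms, start_time=...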
@dataclass @dataclass
...@@ -26,15 +138,20 @@ class SchedulerStats: ...@@ -26,15 +138,20 @@ class SchedulerStats:
gen_throughput: float = 0.0 gen_throughput: float = 0.0
num_queue_reqs: int = 0 num_queue_reqs: int = 0
cache_hit_rate: float = 0.0 cache_hit_rate: float = 0.0
num_grammar_queue_reqs: int = 0
spec_accept_length: float = 0.0 spec_accept_length: float = 0.0
avg_request_queue_latency: float = 0.0 avg_request_queue_latency: float = 0.0
num_prefill_prealloc_queue_reqs: int = 0
num_prefill_infight_queue_reqs: int = 0
num_decode_prealloc_queue_reqs: int = 0
num_decode_transfer_queue_reqs: int = 0
class SchedulerMetricsCollector: class SchedulerMetricsCollector:
def __init__(self, labels: Dict[str, str]) -> None: def __init__(self, labels: Dict[str, str]) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Gauge, Histogram from prometheus_client import Counter, Gauge
self.labels = labels self.labels = labels
self.last_log_time = time.time() self.last_log_time = time.time()
...@@ -74,6 +191,13 @@ class SchedulerMetricsCollector: ...@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.num_grammar_queue_reqs = Gauge(
name="sglang:num_grammar_queue_reqs",
documentation="The number of requests in the grammar waiting queue.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.cache_hit_rate = Gauge( self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate", name="sglang:cache_hit_rate",
documentation="The prefix cache hit rate.", documentation="The prefix cache hit rate.",
...@@ -95,28 +219,98 @@ class SchedulerMetricsCollector: ...@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
# Disaggregation queue metrics
self.num_prefill_prealloc_queue_reqs = Gauge(
name="sglang:num_prefill_prealloc_queue_reqs",
documentation="The number of requests in the prefill prealloc queue.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_prefill_infight_queue_reqs = Gauge(
name="sglang:num_prefill_infight_queue_reqs",
documentation="The number of requests in the prefill infight queue.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_decode_prealloc_queue_reqs = Gauge(
name="sglang:num_decode_prealloc_queue_reqs",
documentation="The number of requests in the decode prealloc queue.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_decode_transfer_queue_reqs = Gauge(
name="sglang:num_decode_transfer_queue_reqs",
documentation="The number of requests in the decode transfer queue.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.num_bootstrap_failed_reqs = Counter(
name="sglang:num_bootstrap_failed_reqs",
documentation="The number of bootstrap failed requests.",
labelnames=labels.keys(),
)
self.num_transfer_failed_reqs = Counter(
name="sglang:num_transfer_failed_reqs",
documentation="The number of transfer failed requests.",
labelnames=labels.keys(),
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge. # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data) gauge.labels(**self.labels).set(data)
def increment_bootstrap_failed_reqs(self) -> None:
self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)
def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
def log_stats(self, stats: SchedulerStats) -> None: def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
self._log_gauge(self.token_usage, stats.token_usage) self._log_gauge(self.token_usage, stats.token_usage)
self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.gen_throughput, stats.gen_throughput)
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.spec_accept_length, stats.spec_accept_length) self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
# Disaggregation metrics
self._log_gauge(
self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
)
self._log_gauge(
self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
)
self._log_gauge(
self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
)
self._log_gauge(
self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
)
self.last_log_time = time.time() self.last_log_time = time.time()
class TokenizerMetricsCollector: class TokenizerMetricsCollector:
def __init__(self, labels: Dict[str, str]) -> None: def __init__(
self,
labels: Dict[str, str],
bucket_time_to_first_token: Optional[List[float]] = None,
bucket_inter_token_latency: Optional[List[float]] = None,
bucket_e2e_request_latency: Optional[List[float]] = None,
collect_tokens_histogram: bool = False,
) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
self.labels = labels self.labels = labels
self.collect_tokens_histogram = collect_tokens_histogram
self.prompt_tokens_total = Counter( self.prompt_tokens_total = Counter(
name="sglang:prompt_tokens_total", name="sglang:prompt_tokens_total",
...@@ -130,6 +324,66 @@ class TokenizerMetricsCollector: ...@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
) )
if collect_tokens_histogram:
bucket_prompt_tokens = [
100,
300,
500,
700,
1000,
1500,
2000,
3000,
4000,
5000,
6000,
7000,
8000,
9000,
10000,
12000,
15000,
20000,
22000,
25000,
30000,
35000,
40000,
]
self.prompt_tokens_histogram = Histogram(
name="sglang:prompt_tokens_histogram",
documentation="Histogram of prompt token length.",
labelnames=labels.keys(),
buckets=bucket_prompt_tokens,
)
bucket_generation_tokens = [
100,
300,
500,
1000,
1200,
1500,
1700,
2000,
2500,
3000,
3500,
4000,
4500,
5000,
6000,
7000,
8000,
9000,
10000,
]
self.generation_tokens_histogram = Histogram(
name="sglang:generation_tokens_histogram",
documentation="Histogram of generation token length.",
labelnames=labels.keys(),
buckets=bucket_generation_tokens,
)
self.cached_tokens_total = Counter( self.cached_tokens_total = Counter(
name="sglang:cached_tokens_total", name="sglang:cached_tokens_total",
documentation="Number of cached prompt tokens.", documentation="Number of cached prompt tokens.",
...@@ -142,11 +396,14 @@ class TokenizerMetricsCollector: ...@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
) )
self.histogram_time_to_first_token = Histogram( self.num_so_requests_total = Counter(
name="sglang:time_to_first_token_seconds", name="sglang:num_so_requests_total",
documentation="Histogram of time to first token in seconds.", documentation="Number of structured output requests processed.",
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ )
if bucket_time_to_first_token is None:
bucket_time_to_first_token = [
0.1, 0.1,
0.2, 0.2,
0.4, 0.4,
...@@ -165,14 +422,33 @@ class TokenizerMetricsCollector: ...@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
100, 100,
200, 200,
400, 400,
], ]
)
self.histogram_inter_token_latency_seconds = Histogram( if bucket_e2e_request_latency is None:
name="sglang:inter_token_latency_seconds", bucket_e2e_request_latency = [
documentation="Histogram of inter-token latency in seconds.", 0.1,
labelnames=labels.keys(), 0.2,
buckets=[ 0.4,
0.6,
0.8,
1,
2,
4,
6,
8,
10,
20,
40,
60,
80,
100,
200,
400,
800,
]
if bucket_inter_token_latency is None:
bucket_inter_token_latency = [
0.002, 0.002,
0.004, 0.004,
0.006, 0.006,
...@@ -196,34 +472,27 @@ class TokenizerMetricsCollector: ...@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
4.000, 4.000,
6.000, 6.000,
8.000, 8.000,
], ]
self.histogram_time_to_first_token = Histogram(
name="sglang:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labels.keys(),
buckets=bucket_time_to_first_token,
)
self.histogram_inter_token_latency_seconds = Histogram(
name="sglang:inter_token_latency_seconds",
documentation="Histogram of inter-token latency in seconds.",
labelnames=labels.keys(),
buckets=bucket_inter_token_latency,
) )
self.histogram_e2e_request_latency = Histogram( self.histogram_e2e_request_latency = Histogram(
name="sglang:e2e_request_latency_seconds", name="sglang:e2e_request_latency_seconds",
documentation="Histogram of End-to-end request latency in seconds", documentation="Histogram of End-to-end request latency in seconds",
labelnames=labels.keys(), labelnames=labels.keys(),
buckets=[ buckets=bucket_e2e_request_latency,
0.1,
0.2,
0.4,
0.6,
0.8,
1,
2,
4,
6,
8,
10,
20,
40,
60,
80,
100,
200,
400,
800,
],
) )
def _log_histogram(self, histogram, data: Union[int, float]) -> None: def _log_histogram(self, histogram, data: Union[int, float]) -> None:
...@@ -235,13 +504,19 @@ class TokenizerMetricsCollector: ...@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
generation_tokens: int, generation_tokens: int,
cached_tokens: int, cached_tokens: int,
e2e_latency: float, e2e_latency: float,
has_grammar: bool,
): ):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
if cached_tokens > 0: if cached_tokens > 0:
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens) self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1) self.num_requests_total.labels(**self.labels).inc(1)
if has_grammar:
self.num_so_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if self.collect_tokens_histogram:
self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
self._log_histogram(self.generation_tokens_histogram, generation_tokens)
def observe_time_to_first_token(self, value: float): def observe_time_to_first_token(self, value: float):
self.histogram_time_to_first_token.labels(**self.labels).observe(value) self.histogram_time_to_first_token.labels(**self.labels).observe(value)
......
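`TokenizerMetricsCollector` now accepts optional bucket lists for time-to-first-token, inter-token latency, and end-to-end latency, plus a `collect_tokens_histogram` flag that enables prompt/generation token-length histograms; `observe_one_finished_request` gains a `has_grammar` flag feeding the new `sglang:num_so_requests_total` counter. A hedged construction sketch; the import path, label key, and bucket values are illustrative assumptions:

import os
import tempfile

# prometheus_client must see PROMETHEUS_MULTIPROC_DIR before it is imported.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

from sglang.srt.metrics.collector import TokenizerMetricsCollector  # assumed module path

collector = TokenizerMetricsCollector(
    labels={"model_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"},  # assumed label key
    bucket_time_to_first_token=[0.1, 0.5, 1, 2, 5, 10],
    bucket_inter_token_latency=[0.005, 0.01, 0.05, 0.1],
    bucket_e2e_request_latency=[1, 5, 10, 60, 300],
    collect_tokens_histogram=True,
)
collector.observe_one_finished_request(
    prompt_tokens=128,
    generation_tokens=256,
    cached_tokens=0,
    e2e_latency=1.8,
    has_grammar=True,  # counted in sglang:num_so_requests_total
)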
...@@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): ...@@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
print(json.dumps(ret)) print(json.dumps(ret))
print("=" * 100) print("=" * 100)
if not json_schema: if not json_schema or json_schema == "INVALID":
return return
# Make sure the json output is valid # Make sure the json output is valid
...@@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): ...@@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
def test_json_generate(self): def test_json_generate(self):
self.run_decode(json_schema=self.json_schema) self.run_decode(json_schema=self.json_schema)
def test_json_invalid(self):
self.run_decode(json_schema="INVALID")
def test_json_openai(self): def test_json_openai(self):
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
...@@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): ...@@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase):
model=self.model, model=self.model,
messages=[ messages=[
{"role": "system", "content": "You are a helpful AI assistant"}, {"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."}, {
"role": "user",
"content": "Introduce the capital of France. Return in a JSON format.",
},
], ],
temperature=0, temperature=0,
max_tokens=128, max_tokens=128,
......
...@@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase): ...@@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase):
"sglang:token_usage", "sglang:token_usage",
"sglang:gen_throughput", "sglang:gen_throughput",
"sglang:num_queue_reqs", "sglang:num_queue_reqs",
"sglang:num_grammar_queue_reqs",
"sglang:cache_hit_rate", "sglang:cache_hit_rate",
"sglang:spec_accept_length", "sglang:spec_accept_length",
"sglang:prompt_tokens_total", "sglang:prompt_tokens_total",
......