".github/vscode:/vscode.git/clone" did not exist on "3a42ebbf5781d6c6408324edeac9d704ca41e6b6"
Unverified commit 125b1199 authored by DarkSharpness, committed by GitHub

support parallel grammar preprocessing (#1996)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent eff468dd
@@ -13,25 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""For constrained decoding."""
import json
from typing import Dict, Optional, Union
from pydantic import BaseModel
try:
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
)
raise
try:
from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
@@ -51,31 +37,6 @@ except ImportError:
return build_regex_from_schema(schema, whitespace_pattern)
try:
from xgrammar import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
except ImportError as e:
class Dummy:
pass
GrammarMatcher = Dummy
GrammarMatcherInitContext = Dummy
GrammarMatcherInitContextCache = Dummy
__all__ = [
"RegexGuide",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
]
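The try/except block above keeps xgrammar an optional dependency: when the import fails, a placeholder class is bound to the same names so that module import, type annotations, and isinstance checks still resolve instead of raising NameError. This commit moves the pattern out of constrained/__init__.py and into xgrammar_cache.py. A minimal self-contained sketch of the pattern (uses_xgrammar is a hypothetical helper, not part of the commit):

try:
    from xgrammar import GrammarMatcher
except ImportError:

    class Dummy:
        pass

    # Bind the placeholder so downstream code can still import the name.
    GrammarMatcher = Dummy


def uses_xgrammar(obj) -> bool:
    # With the fallback this simply returns False instead of raising NameError.
    return isinstance(obj, GrammarMatcher)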
@@ -13,25 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Base tool cache for constrained decoding tools."""
"""Base cache class for constrained decoding tools."""
import time
from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Dict, Tuple
@dataclass
class MapEntry:
event: Event
value: Any
def __iter__(self):
return iter((self.event, self.value))
class BaseToolCache:
def __init__(self, enable=True):
self.enable = enable
self.enable: bool = enable
self.cache: Dict[str, MapEntry] = {}
self.metrics: Dict[str, Any] = {}
self.lock_cache: Lock = Lock()
self.lock_metrics: Lock = Lock()
self.reset()
def reset(self):
self.cache = {}
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
with self.lock_cache:
self.cache = {}
with self.lock_metrics:
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
def query(self, key):
def _init_with_timer(key):
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
def _init_with_timer(self, key) -> Tuple[Any, float]:
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
return val, init_time
def update_time(self, init_time):
with self.lock_metrics:
curr_total = self.metrics["total"]
new_total = curr_total + 1
@@ -39,27 +61,44 @@ class BaseToolCache:
self.metrics["avg_init_time"] = (init_time / new_total) + (
curr_total / new_total
) * self.metrics["avg_init_time"]
return val
if key in self.cache:
self.metrics["hit"] += 1
val = self.cache[key]
else:
# Cache miss or disabled.
val = _init_with_timer(key)
def query(self, key):
if not self.enable:
value, init_time = self._init_with_timer(key)
self.update_time(init_time)
return value
with self.lock_cache:
if key in self.cache:
entry = self.cache[key]
cache_hit = True
else:
entry = MapEntry(Event(), None)
self.cache[key] = entry
cache_hit = False
if self.enable:
with self.lock_metrics:
self.metrics["total"] += 1
self.cache[key] = val
return val
if cache_hit:
self.metrics["hit"] += 1
if cache_hit:
entry.event.wait()
else:
entry.value, init_time = self._init_with_timer(key)
self.update_time(init_time)
entry.event.set()
return entry.value
def init_value(self, key):
raise NotImplementedError()
def get_cache_hit_rate(self):
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
with self.lock_metrics:
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
def get_avg_init_time(self):
return self.metrics["avg_init_time"]
with self.lock_metrics:
return self.metrics["avg_init_time"]
@@ -13,50 +13,44 @@ limitations under the License.
"""Cache for the compressed finite state machine."""
import logging
from typing import List, Optional, Tuple, Union
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union
import torch
from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
# from sglang.srt.managers.schedule_batch import Req
from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.outlines_jump_forward import (
OutlinesJumpCache,
OutlinesJumpForwardMap,
)
from sglang.srt.constrained.xgrammar_cache import (
GrammarMatcher,
XGrammarBackend,
XGrammarJumpCache,
)
logger = logging.getLogger(__name__)
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
class XGrammarJump:
pass
class JumpHelper:
data: Union[List, str]
state: int
suffix_ids: List[int]
def __init__(
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
) -> None:
self.data = data
self.state = state
self.suffix_ids = suffix_ids
self.data: Union[List, str] = data
self.state: int = state
self.suffix_ids: List[int] = suffix_ids
def can_jump(self):
return len(self.data) > 0
class Grammar:
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
jump_map: Union[XGrammarJump, JumpForwardMap, None]
def __init__(
self,
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
jump_map: Union[XGrammarJump, JumpForwardMap, None],
jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
) -> None:
self.grammar = grammar
self.jump_map = jump_map
@@ -69,10 +63,10 @@ class Grammar:
self.grammar = guide, guide.get_next_state(state, token)
def try_jump(self, tokenizer) -> JumpHelper:
if isinstance(self.jump_map, XGrammarJump):
if isinstance(self.jump_map, XGrammarJumpCache):
assert isinstance(self.grammar, GrammarMatcher)
return JumpHelper(self.grammar.find_jump_forward_string())
elif isinstance(self.jump_map, JumpForwardMap):
elif isinstance(self.jump_map, OutlinesJumpForwardMap):
assert isinstance(self.grammar, Tuple)
_, state = self.grammar
@@ -103,7 +97,7 @@ class Grammar:
if isinstance(helper.data, str):
return helper.data, -1
else:
assert isinstance(self.jump_map, JumpForwardMap)
assert isinstance(self.jump_map, OutlinesJumpForwardMap)
return self.jump_map.jump_forward_symbol(helper.state)
def jump_and_retokenize(
@@ -129,7 +123,7 @@ class Grammar:
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
if isinstance(self.grammar, GrammarMatcher):
# Note that this bitmask is a bitset, not bool
bitmask = self.grammar.find_next_token_bitmask()
bitmask = self.grammar.get_next_token_bitmask()
# Mask the tokens that are not allowed
vocab_mask[
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
@@ -140,9 +134,7 @@ class Grammar:
vocab_mask[guide.get_next_instruction(state).tokens] = 0
class GrammarCache:
grammar_cache: Union[BNFCache, FSMCache]
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
class GrammarBackend:
def __init__(
self,
@@ -153,38 +145,38 @@ class GrammarCache:
backend=None,
allow_jump=False,
):
self.executor = ThreadPoolExecutor()
self.backend = backend
if backend == "xgrammar":
self.grammar_cache = BNFCache(
self.grammar_cache = XGrammarBackend(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
whitespace_patterns=whitespace_patterns,
)
self.jump_cache = XGrammarJump() if allow_jump else None
self.jump_cache = XGrammarJumpCache() if allow_jump else None
else:
assert backend == "outlines"
self.grammar_cache = FSMCache(
self.grammar_cache = OutlinesCache(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
constrained_json_whitespace_pattern=whitespace_patterns,
enable=True,
)
self.jump_cache = JumpForwardCache() if allow_jump else None
self.jump_cache = OutlinesJumpCache() if allow_jump else None
def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, BNFCache):
assert not isinstance(self.jump_cache, JumpForwardCache)
def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, XGrammarBackend):
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
else:
jump_map = None
guide, regex = self.grammar_cache.query(key)
if isinstance(self.jump_cache, JumpForwardCache):
jump_map = self.jump_cache.query(regex)
jump_map = self.jump_cache.query(regex)
return Grammar((guide, 0), jump_map)
def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
return self.executor.submit(self._query, key, vocab_size)
def reset(self):
if isinstance(self.grammar_cache, FSMCache):
self.grammar_cache.reset()
if isinstance(self.jump_cache, JumpForwardCache):
self.jump_cache.reset()
self.grammar_cache.reset()
self.jump_cache.reset()
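GrammarBackend.query now submits _query to a ThreadPoolExecutor and returns a concurrent.futures.Future, so grammar compilation for several requests proceeds in parallel with scheduling; the caller polls done() and unwraps result() when ready. A minimal sketch of that submit-then-poll shape (compile_grammar is a hypothetical stand-in for _query):

import time
from concurrent.futures import Future, ThreadPoolExecutor


def compile_grammar(key, vocab_size):
    time.sleep(0.1)  # stand-in for slow grammar preprocessing
    return f"grammar for {key} (vocab={vocab_size})"


with ThreadPoolExecutor() as executor:
    future: Future = executor.submit(compile_grammar, ("json", "{}"), 128256)
    while not future.done():   # the scheduler keeps serving other requests
        time.sleep(0.01)
    print(future.result())     # once done, swap the Future for its value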
@@ -17,16 +17,17 @@ limitations under the License.
import logging
from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.base_tool_cache import BaseToolCache
logger = logging.getLogger(__name__)
class FSMCache(BaseToolCache):
class OutlinesCache(BaseToolCache):
def __init__(
self,
tokenizer_path,
@@ -74,7 +75,7 @@ class FSMCache(BaseToolCache):
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_schema(
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.constrained_json_whitespace_pattern,
)
......
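The first hunk of this commit elides most of the build_regex_from_object shim that OutlinesCache relies on here; from the visible except ImportError: line and the trailing return build_regex_from_schema(schema, whitespace_pattern), it adapts the newer outlines name to the older one. A plausible reconstruction, hedged (the real body may differ):

try:
    # Older outlines releases export build_regex_from_object directly.
    from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
    # Newer releases renamed it; wrap the new name under the old one.
    from outlines.fsm.json_schema import build_regex_from_schema

    def build_regex_from_object(schema, whitespace_pattern=None):
        return build_regex_from_schema(schema, whitespace_pattern)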
@@ -14,7 +14,7 @@ limitations under the License.
"""
"""
Faster constrained decoding.
Faster constrained decoding with jump forward decoding / compressed finite state machine.
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
"""
@@ -23,15 +23,10 @@ import logging
from collections import defaultdict
import interegular
import outlines.caching
from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_tool_cache import BaseToolCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
@@ -47,7 +42,7 @@ class JumpEdge:
byte_next_state: int = None
class JumpForwardMap:
class OutlinesJumpForwardMap:
def __init__(self, regex_string):
@disk_cache()
def _init_state_to_jump_forward(regex_string):
@@ -169,12 +164,12 @@ class JumpForwardMap:
)
class JumpForwardCache(BaseToolCache):
class OutlinesJumpCache(BaseToolCache):
def __init__(self):
super().__init__()
def init_value(self, regex):
forward_map = JumpForwardMap(regex)
forward_map = OutlinesJumpForwardMap(regex)
if forward_map.state_to_jump_forward:
return forward_map
else:
@@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache):
def test_main(regex_string):
jump_forward_map = JumpForwardMap(regex_string)
jump_forward_map = OutlinesJumpForwardMap(regex_string)
for state, e in jump_forward_map.state_to_jump_forward.items():
if e.symbol is not None:
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
......
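Jump-forward decoding (see the compressed-FSM blog post linked above) exploits DFA states that have exactly one outgoing transition: from such a state the next characters are fully determined by the regex, so they can be appended to the output without sampling from the model. A toy illustration of the idea, not the outlines-based implementation:

def jump_forward(dfa, state):
    # Follow forced single-edge transitions, collecting the forced chars.
    forced = []
    while len(dfa.get(state, {})) == 1:
        char, state = next(iter(dfa[state].items()))
        forced.append(char)
    return "".join(forced), state


# DFA for the literal prefix '"name": ' as {state: {char: next_state}}:
dfa = {i: {c: i + 1} for i, c in enumerate('"name": ')}
assert jump_forward(dfa, 0) == ('"name": ', 8)  # 8 chars emitted, no model calls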
@@ -17,18 +17,29 @@ from typing import Tuple
from transformers import AutoTokenizer
from sglang.srt.constrained import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
try:
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
except ImportError as e:
class Dummy:
pass
GrammarMatcher = Dummy
CompiledGrammar = Dummy
CachedGrammarCompiler = Dummy
MAX_ROLLBACK_TOKENS = 10
class BNFCache:
grammar_cache: GrammarMatcherInitContextCache
class XGrammarJumpCache:
"""A dummy class."""
def reset(self):
pass
class XGrammarBackend:
def __init__(
self,
tokenizer_path,
@@ -41,16 +52,16 @@ class BNFCache:
return
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
self.grammar_cache = GrammarMatcherInitContextCache(
self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
tokenizer_or_vocab=tokenizer
)
def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
key_type, key_string = key
if key_type == "json":
return self.grammar_cache.get_init_context_for_json_schema(key_string)
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
elif key_type == "regex":
raise ValueError(f"regex hasn't been supported by xgrammar yet")
raise ValueError("regex hasn't been supported by xgrammar yet")
else:
raise ValueError(f"Invalid key_type: {key_type}")
@@ -59,3 +70,6 @@ class BNFCache:
return GrammarMatcher(
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
)
def reset(self):
self.grammar_cache.clear()
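Putting the xgrammar pieces together, and assuming only the API surface exercised in this diff (CachedGrammarCompiler(tokenizer_or_vocab=...), get_compiled_grammar_for_json_schema, and the GrammarMatcher constructor), a hedged end-to-end sketch; the model name, schema, and vocabulary size are illustrative:

from transformers import AutoTokenizer
from xgrammar import CachedGrammarCompiler, GrammarMatcher

# Illustrative model; any tokenizer accepted by the compiler works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
compiler = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
ctx = compiler.get_compiled_grammar_for_json_schema(schema)  # cached per schema
matcher = GrammarMatcher(ctx, max_rollback_tokens=10, mask_vocab_size=128256)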
@@ -29,7 +29,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarCache
from sglang.srt.constrained.grammar import GrammarBackend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -234,11 +234,12 @@ class Scheduler:
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the FSM cache for constrained generation
# Init the grammar cache for constrained generation
self.grammar_cache = None
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
self.grammar_cache = GrammarCache(
self.grammar_cache = GrammarBackend(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
@@ -455,7 +456,7 @@ class Scheduler:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(recv_req.input_ids) - 1
# Init regex FSM or BNF
# Init grammar cache for this request
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
@@ -488,7 +489,10 @@ class Scheduler:
self.max_req_len - len(req.origin_input_ids) - 1,
)
self.waiting_queue.append(req)
if req.grammar is not None:
self.grammar_queue.append(req)
else:
self.waiting_queue.append(req)
def handle_embedding_request(
self,
@@ -634,6 +638,17 @@ class Scheduler:
return self.running_batch
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar queue is ready
if self.grammar_queue:
new_grammar_queue = []
for req in self.grammar_queue:
if req.grammar.done():
req.grammar = req.grammar.result()
self.waiting_queue.append(req)
else:
new_grammar_queue.append(req)
self.grammar_queue = new_grammar_queue
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
......
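The grammar-queue drain shown earlier in this hunk is deliberately non-blocking: on each scheduling tick, requests whose Future has resolved swap req.grammar from the Future to its Grammar result and join the waiting queue, while unfinished ones are retried on the next tick, so a slow compile never stalls the scheduler. The same drain pattern in isolation (pending and ready are hypothetical names, not Scheduler attributes):

def drain_ready(pending, ready):
    still_pending = []
    for fut in pending:
        if fut.done():
            ready.append(fut.result())  # swap the Future for its value
        else:
            still_pending.append(fut)   # retry on the next tick
    return still_pending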
@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
@@ -129,6 +128,8 @@ class ModelRunner:
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
from outlines.caching import disable_cache
disable_cache()
global_server_args_dict.update(
......
@@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase):
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
assert isinstance(js_obj["name"], str), f"{js_obj=}"
assert isinstance(js_obj["population"], int), f"{js_obj=}"
def test_mix_json_and_other(self):
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
......