Unverified commit 125b1199, authored by DarkSharpness and committed by GitHub
Browse files

support parallel grammar preprocessing (#1996)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent eff468dd
...@@ -13,25 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,25 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""For constrained decoding."""
import json import json
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
try:
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
)
raise
try: try:
from outlines.fsm.json_schema import build_regex_from_object from outlines.fsm.json_schema import build_regex_from_object
except ImportError: except ImportError:
...@@ -51,31 +37,6 @@ except ImportError: ...@@ -51,31 +37,6 @@ except ImportError:
return build_regex_from_schema(schema, whitespace_pattern) return build_regex_from_schema(schema, whitespace_pattern)
try:
from xgrammar import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
except ImportError as e:
class Dummy:
pass
GrammarMatcher = Dummy
GrammarMatcherInitContext = Dummy
GrammarMatcherInitContextCache = Dummy
__all__ = [ __all__ = [
"RegexGuide",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object", "build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
] ]
...@@ -13,25 +13,47 @@ See the License for the specific language governing permissions and ...@@ -13,25 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""Base tool cache for constrained decoding tools.""" """Base cache class for constrained decoding tools."""
import time import time
from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Dict, Tuple
@dataclass
class MapEntry:
event: Event
value: Any
def __iter__(self):
return iter((self.event, self.value))
class BaseToolCache: class BaseToolCache:
def __init__(self, enable=True): def __init__(self, enable=True):
self.enable = enable self.enable: bool = enable
self.cache: Dict[str, MapEntry] = {}
self.metrics: Dict[str, Any] = {}
self.lock_cache: Lock = Lock()
self.lock_metrics: Lock = Lock()
self.reset() self.reset()
def reset(self): def reset(self):
with self.lock_cache:
self.cache = {} self.cache = {}
with self.lock_metrics:
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0} self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
def query(self, key): def _init_with_timer(self, key) -> Tuple[Any, float]:
def _init_with_timer(key):
start = time.monotonic() start = time.monotonic()
val = self.init_value(key) val = self.init_value(key)
init_time = time.monotonic() - start init_time = time.monotonic() - start
return val, init_time
def update_time(self, init_time):
with self.lock_metrics:
curr_total = self.metrics["total"] curr_total = self.metrics["total"]
new_total = curr_total + 1 new_total = curr_total + 1
...@@ -39,27 +61,44 @@ class BaseToolCache: ...@@ -39,27 +61,44 @@ class BaseToolCache:
self.metrics["avg_init_time"] = (init_time / new_total) + ( self.metrics["avg_init_time"] = (init_time / new_total) + (
curr_total / new_total curr_total / new_total
) * self.metrics["avg_init_time"] ) * self.metrics["avg_init_time"]
return val
def query(self, key):
if not self.enable:
value, init_time = self._init_with_timer(key)
self.update_time(init_time)
return value
with self.lock_cache:
if key in self.cache: if key in self.cache:
self.metrics["hit"] += 1 entry = self.cache[key]
val = self.cache[key] cache_hit = True
else: else:
# Cache miss or disabled. entry = MapEntry(Event(), None)
val = _init_with_timer(key) self.cache[key] = entry
cache_hit = False
if self.enable: with self.lock_metrics:
self.metrics["total"] += 1 self.metrics["total"] += 1
self.cache[key] = val if cache_hit:
return val self.metrics["hit"] += 1
if cache_hit:
entry.event.wait()
else:
entry.value, init_time = self._init_with_timer(key)
self.update_time(init_time)
entry.event.set()
return entry.value
def init_value(self, key): def init_value(self, key):
raise NotImplementedError() raise NotImplementedError()
def get_cache_hit_rate(self): def get_cache_hit_rate(self):
with self.lock_metrics:
if self.metrics["total"] == 0: if self.metrics["total"] == 0:
return 0 return 0
return self.metrics["hit"] / self.metrics["total"] return self.metrics["hit"] / self.metrics["total"]
def get_avg_init_time(self): def get_avg_init_time(self):
with self.lock_metrics:
return self.metrics["avg_init_time"] return self.metrics["avg_init_time"]
...@@ -13,50 +13,44 @@ limitations under the License. ...@@ -13,50 +13,44 @@ limitations under the License.
"""Cache for the compressed finite state machine.""" """Cache for the compressed finite state machine."""
import logging import logging
from typing import List, Optional, Tuple, Union from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union
import torch import torch
from sglang.srt.constrained import GrammarMatcher, RegexGuide from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache from sglang.srt.constrained.outlines_jump_forward import (
from sglang.srt.constrained.fsm_cache import FSMCache OutlinesJumpCache,
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap OutlinesJumpForwardMap,
)
# from sglang.srt.managers.schedule_batch import Req from sglang.srt.constrained.xgrammar_cache import (
GrammarMatcher,
XGrammarBackend,
XGrammarJumpCache,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
class XGrammarJump:
pass
class JumpHelper: class JumpHelper:
data: Union[List, str]
state: int
suffix_ids: List[int]
def __init__( def __init__(
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[] self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
) -> None: ) -> None:
self.data = data self.data: Union[List, str] = data
self.state = state self.state: int = state
self.suffix_ids = suffix_ids self.suffix_ids: List[int] = suffix_ids
def can_jump(self): def can_jump(self):
return len(self.data) > 0 return len(self.data) > 0
class Grammar: class Grammar:
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
jump_map: Union[XGrammarJump, JumpForwardMap, None]
def __init__( def __init__(
self, self,
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]], grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
jump_map: Union[XGrammarJump, JumpForwardMap, None], jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
) -> None: ) -> None:
self.grammar = grammar self.grammar = grammar
self.jump_map = jump_map self.jump_map = jump_map
...@@ -69,10 +63,10 @@ class Grammar: ...@@ -69,10 +63,10 @@ class Grammar:
self.grammar = guide, guide.get_next_state(state, token) self.grammar = guide, guide.get_next_state(state, token)
def try_jump(self, tokenizer) -> JumpHelper: def try_jump(self, tokenizer) -> JumpHelper:
if isinstance(self.jump_map, XGrammarJump): if isinstance(self.jump_map, XGrammarJumpCache):
assert isinstance(self.grammar, GrammarMatcher) assert isinstance(self.grammar, GrammarMatcher)
return JumpHelper(self.grammar.find_jump_forward_string()) return JumpHelper(self.grammar.find_jump_forward_string())
elif isinstance(self.jump_map, JumpForwardMap): elif isinstance(self.jump_map, OutlinesJumpForwardMap):
assert isinstance(self.grammar, Tuple) assert isinstance(self.grammar, Tuple)
_, state = self.grammar _, state = self.grammar
...@@ -103,7 +97,7 @@ class Grammar: ...@@ -103,7 +97,7 @@ class Grammar:
if isinstance(helper.data, str): if isinstance(helper.data, str):
return helper.data, -1 return helper.data, -1
else: else:
assert isinstance(self.jump_map, JumpForwardMap) assert isinstance(self.jump_map, OutlinesJumpForwardMap)
return self.jump_map.jump_forward_symbol(helper.state) return self.jump_map.jump_forward_symbol(helper.state)
def jump_and_retokenize( def jump_and_retokenize(
...@@ -129,7 +123,7 @@ class Grammar: ...@@ -129,7 +123,7 @@ class Grammar:
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int): def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
if isinstance(self.grammar, GrammarMatcher): if isinstance(self.grammar, GrammarMatcher):
# Note that this bitmask is a bitset, not bool # Note that this bitmask is a bitset, not bool
bitmask = self.grammar.find_next_token_bitmask() bitmask = self.grammar.get_next_token_bitmask()
# Mask the tokens that are not allowed # Mask the tokens that are not allowed
vocab_mask[ vocab_mask[
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size) self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
...@@ -140,9 +134,7 @@ class Grammar: ...@@ -140,9 +134,7 @@ class Grammar:
vocab_mask[guide.get_next_instruction(state).tokens] = 0 vocab_mask[guide.get_next_instruction(state).tokens] = 0
class GrammarCache: class GrammarBackend:
grammar_cache: Union[BNFCache, FSMCache]
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
def __init__( def __init__(
self, self,
...@@ -153,38 +145,38 @@ class GrammarCache: ...@@ -153,38 +145,38 @@ class GrammarCache:
backend=None, backend=None,
allow_jump=False, allow_jump=False,
): ):
self.executor = ThreadPoolExecutor()
self.backend = backend
if backend == "xgrammar": if backend == "xgrammar":
self.grammar_cache = BNFCache( self.grammar_cache = XGrammarBackend(
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict, tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
whitespace_patterns=whitespace_patterns, whitespace_patterns=whitespace_patterns,
) )
self.jump_cache = XGrammarJump() if allow_jump else None self.jump_cache = XGrammarJumpCache() if allow_jump else None
else: else:
assert backend == "outlines" assert backend == "outlines"
self.grammar_cache = FSMCache( self.grammar_cache = OutlinesCache(
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict, tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
constrained_json_whitespace_pattern=whitespace_patterns, constrained_json_whitespace_pattern=whitespace_patterns,
enable=True,
) )
self.jump_cache = JumpForwardCache() if allow_jump else None self.jump_cache = OutlinesJumpCache() if allow_jump else None
def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, BNFCache): if isinstance(self.grammar_cache, XGrammarBackend):
assert not isinstance(self.jump_cache, JumpForwardCache)
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache) return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
else: else:
jump_map = None
guide, regex = self.grammar_cache.query(key) guide, regex = self.grammar_cache.query(key)
if isinstance(self.jump_cache, JumpForwardCache):
jump_map = self.jump_cache.query(regex) jump_map = self.jump_cache.query(regex)
return Grammar((guide, 0), jump_map) return Grammar((guide, 0), jump_map)
def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
return self.executor.submit(self._query, key, vocab_size)
def reset(self): def reset(self):
if isinstance(self.grammar_cache, FSMCache):
self.grammar_cache.reset() self.grammar_cache.reset()
if isinstance(self.jump_cache, JumpForwardCache):
self.jump_cache.reset() self.jump_cache.reset()
...@@ -17,16 +17,17 @@ limitations under the License. ...@@ -17,16 +17,17 @@ limitations under the License.
import logging import logging
from interegular import InvalidSyntax, parse_pattern from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.base_tool_cache import BaseToolCache from sglang.srt.constrained.base_tool_cache import BaseToolCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FSMCache(BaseToolCache): class OutlinesCache(BaseToolCache):
def __init__( def __init__(
self, self,
tokenizer_path, tokenizer_path,
...@@ -74,7 +75,7 @@ class FSMCache(BaseToolCache): ...@@ -74,7 +75,7 @@ class FSMCache(BaseToolCache):
key_type, key_string = key key_type, key_string = key
if key_type == "json": if key_type == "json":
try: try:
regex = build_regex_from_schema( regex = build_regex_from_object(
key_string, key_string,
whitespace_pattern=self.constrained_json_whitespace_pattern, whitespace_pattern=self.constrained_json_whitespace_pattern,
) )
......
...@@ -14,7 +14,7 @@ limitations under the License. ...@@ -14,7 +14,7 @@ limitations under the License.
""" """
""" """
Faster constrained decoding. Faster constrained decoding with jump forward decoding / compressed finite state machine.
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
""" """
...@@ -23,15 +23,10 @@ import logging ...@@ -23,15 +23,10 @@ import logging
from collections import defaultdict from collections import defaultdict
import interegular import interegular
import outlines.caching
from interegular import InvalidSyntax from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_tool_cache import BaseToolCache from sglang.srt.constrained.base_tool_cache import BaseToolCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
...@@ -47,7 +42,7 @@ class JumpEdge: ...@@ -47,7 +42,7 @@ class JumpEdge:
byte_next_state: int = None byte_next_state: int = None
class JumpForwardMap: class OutlinesJumpForwardMap:
def __init__(self, regex_string): def __init__(self, regex_string):
@disk_cache() @disk_cache()
def _init_state_to_jump_forward(regex_string): def _init_state_to_jump_forward(regex_string):
...@@ -169,12 +164,12 @@ class JumpForwardMap: ...@@ -169,12 +164,12 @@ class JumpForwardMap:
) )
class JumpForwardCache(BaseToolCache): class OutlinesJumpCache(BaseToolCache):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def init_value(self, regex): def init_value(self, regex):
forward_map = JumpForwardMap(regex) forward_map = OutlinesJumpForwardMap(regex)
if forward_map.state_to_jump_forward: if forward_map.state_to_jump_forward:
return forward_map return forward_map
else: else:
...@@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache): ...@@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache):
def test_main(regex_string): def test_main(regex_string):
jump_forward_map = JumpForwardMap(regex_string) jump_forward_map = OutlinesJumpForwardMap(regex_string)
for state, e in jump_forward_map.state_to_jump_forward.items(): for state, e in jump_forward_map.state_to_jump_forward.items():
if e.symbol is not None: if e.symbol is not None:
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state) jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
......
...@@ -17,18 +17,29 @@ from typing import Tuple ...@@ -17,18 +17,29 @@ from typing import Tuple
from transformers import AutoTokenizer from transformers import AutoTokenizer
from sglang.srt.constrained import ( try:
GrammarMatcher, from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
GrammarMatcherInitContext, except ImportError as e:
GrammarMatcherInitContextCache,
) class Dummy:
pass
GrammarMatcher = Dummy
CompiledGrammar = Dummy
CachedGrammarCompiler = Dummy
MAX_ROLLBACK_TOKENS = 10 MAX_ROLLBACK_TOKENS = 10
class BNFCache: class XGrammarJumpCache:
grammar_cache: GrammarMatcherInitContextCache """A dummy class."""
def reset(self):
pass
class XGrammarBackend:
def __init__( def __init__(
self, self,
tokenizer_path, tokenizer_path,
...@@ -41,16 +52,16 @@ class BNFCache: ...@@ -41,16 +52,16 @@ class BNFCache:
return return
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
self.grammar_cache = GrammarMatcherInitContextCache( self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
tokenizer_or_vocab=tokenizer tokenizer_or_vocab=tokenizer
) )
def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext: def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
key_type, key_string = key key_type, key_string = key
if key_type == "json": if key_type == "json":
return self.grammar_cache.get_init_context_for_json_schema(key_string) return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
elif key_type == "regex": elif key_type == "regex":
raise ValueError(f"regex hasn't been supported by xgrammar yet") raise ValueError("regex hasn't been supported by xgrammar yet")
else: else:
raise ValueError(f"Invalid key_type: {key_type}") raise ValueError(f"Invalid key_type: {key_type}")
...@@ -59,3 +70,6 @@ class BNFCache: ...@@ -59,3 +70,6 @@ class BNFCache:
return GrammarMatcher( return GrammarMatcher(
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
) )
def reset(self):
self.grammar_cache.clear()
...@@ -29,7 +29,7 @@ import zmq ...@@ -29,7 +29,7 @@ import zmq
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarCache from sglang.srt.constrained.grammar import GrammarBackend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
...@@ -234,11 +234,12 @@ class Scheduler: ...@@ -234,11 +234,12 @@ class Scheduler:
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
) )
# Init the FSM cache for constrained generation # Init the grammar cache for constrained generation
self.grammar_cache = None self.grammar_cache = None
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init: if not server_args.skip_tokenizer_init:
self.grammar_cache = GrammarCache( self.grammar_cache = GrammarBackend(
server_args.tokenizer_path, server_args.tokenizer_path,
{ {
"tokenizer_mode": server_args.tokenizer_mode, "tokenizer_mode": server_args.tokenizer_mode,
...@@ -455,7 +456,7 @@ class Scheduler: ...@@ -455,7 +456,7 @@ class Scheduler:
# By default, only return the logprobs for output tokens # By default, only return the logprobs for output tokens
req.logprob_start_len = len(recv_req.input_ids) - 1 req.logprob_start_len = len(recv_req.input_ids) - 1
# Init regex FSM or BNF # Init grammar cache for this request
if ( if (
req.sampling_params.json_schema is not None req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None or req.sampling_params.regex is not None
...@@ -488,6 +489,9 @@ class Scheduler: ...@@ -488,6 +489,9 @@ class Scheduler:
self.max_req_len - len(req.origin_input_ids) - 1, self.max_req_len - len(req.origin_input_ids) - 1,
) )
if req.grammar is not None:
self.grammar_queue.append(req)
else:
self.waiting_queue.append(req) self.waiting_queue.append(req)
def handle_embedding_request( def handle_embedding_request(
...@@ -634,6 +638,17 @@ class Scheduler: ...@@ -634,6 +638,17 @@ class Scheduler:
return self.running_batch return self.running_batch
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar queue is ready
if self.grammar_queue:
new_grammar_queue = []
for req in self.grammar_queue:
if req.grammar.done():
req.grammar = req.grammar.result()
self.waiting_queue.append(req)
else:
new_grammar_queue.append(req)
self.grammar_queue = new_grammar_queue
# Handle the cases where prefill is not allowed # Handle the cases where prefill is not allowed
if ( if (
self.batch_is_full or len(self.waiting_queue) == 0 self.batch_is_full or len(self.waiting_queue) == 0
......
...@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model ...@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
...@@ -129,6 +128,8 @@ class ModelRunner: ...@@ -129,6 +128,8 @@ class ModelRunner:
if server_args.show_time_cost: if server_args.show_time_cost:
enable_show_time_cost() enable_show_time_cost()
if server_args.disable_disk_cache: if server_args.disable_disk_cache:
from outlines.caching import disable_cache
disable_cache() disable_cache()
global_server_args_dict.update( global_server_args_dict.update(
......
...@@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase): ...@@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase):
except (TypeError, json.decoder.JSONDecodeError): except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text) print("JSONDecodeError", text)
raise raise
assert isinstance(js_obj["name"], str) assert isinstance(js_obj["name"], str), f"{js_obj=}"
assert isinstance(js_obj["population"], int) assert isinstance(js_obj["population"], int), f"{js_obj=}"
def test_mix_json_and_other(self): def test_mix_json_and_other(self):
json_schemas = [None, None, self.json_schema, self.json_schema] * 10 json_schemas = [None, None, self.json_schema, self.json_schema] * 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment