Unverified commit bca832c7, authored by JieXin Liang and committed by GitHub

[Fix] fix outlines and xgrammar (#4947)

parent d9dd5298
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 import dataclasses
 import logging
 from collections import defaultdict
+from typing import Optional

 import interegular
 from interegular import InvalidSyntax
-from outlines.caching import cache as disk_cache
+from outlines.caching import cache
+
+from sglang.srt.utils import get_bool_env_var

 try:
     # outlines >= 0.1.0

@@ -34,6 +37,9 @@ except ImportError:

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

+# The env var is set in sglang.srt.server_args.ServerArgs.__post_init__
+DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
 logger = logging.getLogger(__name__)

@@ -45,6 +51,13 @@ class JumpEdge:
     byte_next_state: int = None


+def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+    if not DISABLE_DISK_CACHE:
+        return cache(expire, typed, ignore)
+    else:
+        return lambda fn: fn  # no-op decorator when the disk cache is disabled
+
+
 @disk_cache()
 def init_state_to_jump_forward(regex_string):
     try:
...
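For reviewers, here is a minimal, self-contained sketch of the behavior introduced above: a "true"-by-default boolean env var decides whether the outlines disk cache decorator is applied or replaced by a pass-through. The `_demo` helpers approximate `sglang.srt.utils.get_bool_env_var` and the decorator in the hunk; they are illustrative and not part of the commit.

# Illustrative sketch, not part of this commit: how the "true"-by-default env var
# toggles between outlines' disk cache and a pass-through decorator.
import os
from typing import Optional


def get_bool_env_var_demo(name: str, default: str = "false") -> bool:
    # Assumed semantics of sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).lower() in ("true", "1")


DISABLE_DISK_CACHE_DEMO = get_bool_env_var_demo("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")


def disk_cache_demo(expire: Optional[float] = None):
    if not DISABLE_DISK_CACHE_DEMO:
        # Only touch outlines when caching is actually enabled.
        from outlines.caching import cache

        return cache(expire)
    return lambda fn: fn  # pass-through decorator


@disk_cache_demo()
def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    # With the env var unset, the default "true" disables the disk cache and
    # square() runs undecorated.
    print(square(4))  # 16

Keeping the fallback a pass-through (rather than returning None) is what keeps functions decorated with @disk_cache() callable when the cache is switched off.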
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py

from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
logits_ptr,
bitmask_ptr,
indices_ptr,
num_rows,
vocab_size,
logits_strides,
bitmask_strides,
NUM_SMS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
the masked logits will be set to -inf.
Parameters
----------
logits_ptr : tl.tensor
Pointer to the logits tensor to apply the bitmask to.
bitmask_ptr : tl.tensor
Pointer to the bitmask tensor to apply.
indices_ptr : Optional[tl.tensor]
Optional pointer to indices tensor specifying which rows to apply the mask to.
num_rows : int
Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
vocab_size : int
        Size of the vocabulary dimension. If the logits tensor has no vocab padding, this is
        the same as its second dimension; otherwise it is the actual vocabulary size.
logits_strides : int
Stride between rows in the logits tensor.
bitmask_strides : int
Stride between rows in the bitmask tensor.
NUM_SMS : int
Number of streaming multiprocessors to use.
BLOCK_SIZE : int
Size of processing blocks.
"""
pid = tl.program_id(0)
num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
row_id = work_id // num_blocks
block_offset = (work_id % num_blocks) * BLOCK_SIZE
batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
offsets = block_offset + tl.arange(0, BLOCK_SIZE)
bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
vocab_mask = offsets < vocab_size
packed_bitmask_mask = bitmask_offsets < bitmask_strides
packed_bitmask = tl.load(
bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
packed_bitmask_mask,
)
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
tl.store(
logits_ptr + batch_id * logits_strides + offsets,
-float("inf"),
vocab_mask & bitmask,
        )


def apply_token_bitmask_inplace_triton(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
):
NUM_SMS = get_device_core_count()
BLOCK_SIZE = 4096
BITS_PER_BLOCK = 32
# Check input dtype
assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
# Check input tensor shapes.
logits_shape = logits.shape
bitmask_shape = bitmask.shape
if logits.ndim == 1:
logits_shape = (1, logits_shape[0])
if bitmask.ndim == 1:
bitmask_shape = (1, bitmask_shape[0])
required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
assert required_bitmask_width >= bitmask_shape[1], (
f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
)
vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
num_rows = None
if isinstance(indices, list) or isinstance(indices, torch.Tensor):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
num_rows = indices.shape[0]
else:
assert (
logits_shape[0] == bitmask_shape[0]
), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
num_rows = logits_shape[0]
if NUM_SMS > 0:
grid = (NUM_SMS,)
else:
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
grid = (num_rows * num_blocks,)
NUM_SMS = triton.next_power_of_2(grid[0])
apply_token_bitmask_inplace_kernel[grid](
logits,
bitmask,
indices,
num_rows,
vocab_size,
logits_shape[1],
bitmask_shape[1],
NUM_SMS,
BLOCK_SIZE,
num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
num_stages=3,
)
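To make the packing convention of the new Triton op concrete, here is a hedged, self-contained check (not part of the commit): a boolean allow-mask is packed LSB-first into int32 words, 32 tokens per word, and the op should match a plain masked_fill reference. The helper names are mine; the import path follows the xgrammar_backend hunk below. For example, a vocabulary of 40 tokens needs ceil(40 / 32) = 2 int32 words per row.

# Illustrative check, not part of this commit.
import torch


def pack_bitmask(allowed: torch.Tensor) -> torch.Tensor:
    """Pack a (batch, vocab) bool tensor into (batch, ceil(vocab/32)) int32, LSB-first."""
    batch, vocab = allowed.shape
    width = (vocab + 31) // 32  # e.g. vocab=40 -> 2 int32 words per row
    packed = torch.zeros(batch, width, dtype=torch.int64)
    for token_id in range(vocab):
        word, bit = divmod(token_id, 32)
        packed[:, word] |= allowed[:, token_id].to(torch.int64) << bit
    # Reinterpret the low 32 bits as signed int32 (two's complement) to match the kernel's dtype.
    packed = torch.where(packed >= 2**31, packed - 2**32, packed)
    return packed.to(torch.int32)


def apply_bitmask_reference(logits: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    """What the kernel computes: logits of tokens whose bit is 0 become -inf."""
    return logits.masked_fill(~allowed, float("-inf"))


if __name__ == "__main__":
    batch, vocab = 2, 40
    logits = torch.randn(batch, vocab)
    allowed = torch.rand(batch, vocab) > 0.5  # tokens the grammar still permits
    expected = apply_bitmask_reference(logits.clone(), allowed)

    if torch.cuda.is_available():
        from sglang.srt.constrained.triton_ops.bitmask_ops import (
            apply_token_bitmask_inplace_triton,
        )

        gpu_logits = logits.clone().cuda()
        apply_token_bitmask_inplace_triton(gpu_logits, pack_bitmask(allowed).cuda())
        assert torch.equal(gpu_logits.cpu(), expected)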
@@ -25,13 +25,16 @@ from xgrammar import (
     StructuralTagItem,
     TokenizerInfo,
     allocate_token_bitmask,
-    apply_token_bitmask_inplace,
 )

 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
+from sglang.srt.constrained.triton_ops.bitmask_ops import (
+    apply_token_bitmask_inplace_triton,
+)
+from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)
@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
         self.override_stop_tokens = override_stop_tokens
         self.finished = False

+        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+        # class init site to avoid re-initializing CUDA in forked subprocess.
+        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+        self.use_token_bitmask_triton = get_bool_env_var(
+            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+        )
+        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+            "cuda", None
+        )
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)

@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
     def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
         return vocab_mask.to(device, non_blocking=True)

-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        apply_token_bitmask_inplace(logits, vocab_mask)
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if (
+            not self.use_token_bitmask_triton
+            and logits.device.type == "cuda"
+            and self.apply_vocab_mask_cuda
+        ):
+            return self.apply_vocab_mask_cuda(logits, vocab_mask)
+        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            return self.apply_vocab_mask_cpu(logits, vocab_mask)
+        apply_token_bitmask_inplace_triton(logits, vocab_mask)

     def copy(self):
         matcher = GrammarMatcher(
...
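The rewritten apply_vocab_mask prefers xgrammar's prebuilt CUDA kernel, then its CPU kernel, and otherwise falls back to the Triton op; setting SGLANG_TOKEN_BITMASK_TRITON=1 forces the Triton path even on CUDA. Below is a reduced, standalone sketch of that precedence; the stub kernels and the helper function are illustrative and not from the commit.

# Illustrative sketch, not part of this commit: the kernel-selection order used by
# XGrammarGrammar.apply_vocab_mask, reduced to a standalone function with stub kernels.
import os
from typing import Callable, Optional


def pick_bitmask_kernel(
    device_type: str,
    cuda_kernel: Optional[Callable],
    cpu_kernel: Optional[Callable],
    triton_kernel: Callable,
) -> Callable:
    force_triton = os.getenv("SGLANG_TOKEN_BITMASK_TRITON", "false").lower() in ("true", "1")
    if not force_triton and device_type == "cuda" and cuda_kernel is not None:
        return cuda_kernel  # xgrammar's fused CUDA kernel, when it was built
    if device_type == "cpu" and cpu_kernel is not None:
        return cpu_kernel  # xgrammar's CPU kernel
    return triton_kernel  # sglang's Triton op: the fallback and the forced path


if __name__ == "__main__":
    stub = lambda name: (lambda logits, mask: name)

    print(pick_bitmask_kernel("cuda", stub("cuda"), stub("cpu"), stub("triton"))(None, None))  # cuda
    os.environ["SGLANG_TOKEN_BITMASK_TRITON"] = "1"
    print(pick_bitmask_kernel("cuda", stub("cuda"), stub("cpu"), stub("triton"))(None, None))  # triton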
@@ -137,11 +137,6 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()

-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
...
@@ -392,6 +392,10 @@ class ServerArgs:
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
         )
+        # Set the env var before the grammar backends are initialized.
+        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
+            "1" if self.disable_outlines_disk_cache else "0"
+        )

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
...
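The hunk above matters for ordering: DISABLE_DISK_CACHE in outlines_jump_forward.py is evaluated once at import time, so ServerArgs has to write the env var before any grammar backend module is imported. A minimal sketch of the round trip, not part of the commit (the flag value here is hard-coded for illustration):

# Illustrative round trip, not part of this commit.
import os

# 1) ServerArgs.__post_init__ runs first and records the server option as "1"/"0".
disable_outlines_disk_cache = True  # e.g. taken from the corresponding server argument
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = "1" if disable_outlines_disk_cache else "0"

# 2) outlines_jump_forward is imported afterwards and reads the value once,
#    defaulting to "true" (cache disabled) when the variable is missing.
value = os.getenv("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
print(value.lower() in ("true", "1"))  # True -> outlines disk cache stays off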