Unverified commit bca832c7 authored by JieXin Liang, committed by GitHub

[Fix] fix outlines and xgrammar (#4947)

parent d9dd5298
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
import dataclasses
import logging
from collections import defaultdict
from typing import Optional
import interegular
from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache
from outlines.caching import cache
from sglang.srt.utils import get_bool_env_var
try:
# outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# The env var is set in sglang.srt.server_args.ServerArgs.__post_init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
logger = logging.getLogger(__name__)
@@ -45,6 +51,13 @@ class JumpEdge:
byte_next_state: int = None
def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    # Use the outlines disk cache unless it is disabled via the env var; otherwise
    # return a no-op decorator that leaves the wrapped function unchanged.
    if not DISABLE_DISK_CACHE:
        return cache(expire, typed, ignore)
    else:
        return lambda fn: fn
@disk_cache()
def init_state_to_jump_forward(regex_string):
try:
......
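The `disk_cache` wrapper above makes the outlines disk cache a no-op when `SGLANG_DISABLE_OUTLINES_DISK_CACHE` is set, replacing the old `ModelRunner` call to `outlines.caching.disable_cache()` that the last hunk below removes. A minimal standalone sketch of the same conditional-decorator pattern, with `functools.lru_cache` standing in for the outlines disk cache and a made-up `compile_regex` function; the env-var parsing here is an assumption, the real helper is `sglang.srt.utils.get_bool_env_var`:

```python
import os
from functools import lru_cache

# Assumed parse of the env var; sglang reads it via get_bool_env_var instead.
DISABLE_DISK_CACHE = os.getenv("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true").lower() in ("1", "true")


def maybe_cache():
    # When caching is enabled, return a real caching decorator;
    # otherwise return a no-op decorator that leaves the function unchanged.
    if not DISABLE_DISK_CACHE:
        return lru_cache(maxsize=None)
    return lambda fn: fn


@maybe_cache()
def compile_regex(regex_string: str) -> str:
    # Hypothetical stand-in for an expensive compilation step worth caching.
    return f"compiled({regex_string})"
```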
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count
@triton.jit
def apply_token_bitmask_inplace_kernel(
logits_ptr,
bitmask_ptr,
indices_ptr,
num_rows,
vocab_size,
logits_strides,
bitmask_strides,
NUM_SMS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
the masked logits will be set to -inf.
Parameters
----------
logits_ptr : tl.tensor
Pointer to the logits tensor to apply the bitmask to.
bitmask_ptr : tl.tensor
Pointer to the bitmask tensor to apply.
indices_ptr : Optional[tl.tensor]
Optional pointer to indices tensor specifying which rows to apply the mask to.
num_rows : int
Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
vocab_size : int
Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
logits_strides : int
Stride between rows in the logits tensor.
bitmask_strides : int
Stride between rows in the bitmask tensor.
NUM_SMS : int
Number of streaming multiprocessors to use.
BLOCK_SIZE : int
Size of processing blocks.
"""
pid = tl.program_id(0)
num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
row_id = work_id // num_blocks
block_offset = (work_id % num_blocks) * BLOCK_SIZE
batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
offsets = block_offset + tl.arange(0, BLOCK_SIZE)
bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
vocab_mask = offsets < vocab_size
packed_bitmask_mask = bitmask_offsets < bitmask_strides
packed_bitmask = tl.load(
bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
packed_bitmask_mask,
)
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
tl.store(
logits_ptr + batch_id * logits_strides + offsets,
-float("inf"),
vocab_mask & bitmask,
)
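# Illustrative eager-PyTorch sketch of the unpack-and-mask step above (not part of the
# Triton kernel): bit t % 32 of word t // 32 is 1 when token t is allowed and 0 when it
# must be masked to -inf. Shown for a single, unindexed row.
def _reference_apply_bitmask_row(logits_row: torch.Tensor, packed_row: torch.Tensor) -> None:
    bits = (packed_row[:, None] >> torch.arange(32, device=packed_row.device)) & 1
    masked = bits.flatten()[: logits_row.numel()] == 0
    logits_row[masked] = float("-inf")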
def apply_token_bitmask_inplace_triton(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
):
NUM_SMS = get_device_core_count()
BLOCK_SIZE = 4096
BITS_PER_BLOCK = 32
# Check input dtype
assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
# Check input tensor shapes.
logits_shape = logits.shape
bitmask_shape = bitmask.shape
if logits.ndim == 1:
logits_shape = (1, logits_shape[0])
if bitmask.ndim == 1:
bitmask_shape = (1, bitmask_shape[0])
required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
assert required_bitmask_width >= bitmask_shape[1], (
f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
)
vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
num_rows = None
if isinstance(indices, list) or isinstance(indices, torch.Tensor):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
num_rows = indices.shape[0]
else:
assert (
logits_shape[0] == bitmask_shape[0]
), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
num_rows = logits_shape[0]
if NUM_SMS > 0:
grid = (NUM_SMS,)
else:
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
grid = (num_rows * num_blocks,)
NUM_SMS = triton.next_power_of_2(grid[0])
apply_token_bitmask_inplace_kernel[grid](
logits,
bitmask,
indices,
num_rows,
vocab_size,
logits_shape[1],
bitmask_shape[1],
NUM_SMS,
BLOCK_SIZE,
num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
num_stages=3,
)
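For context, a hedged sketch of how a caller could build the packed int32 bitmask this wrapper expects from a plain boolean allow-mask and apply it. `pack_allow_mask` is illustrative only; in sglang the bitmask itself comes from xgrammar (see `allocate_token_bitmask` in the next hunk's imports). Its layout, token `t` in word `t // 32`, bit `t % 32`, matches the unpack step in the kernel above; running the final call requires Triton and a CUDA device.

```python
import numpy as np
import torch


def pack_allow_mask(allowed: torch.Tensor) -> torch.Tensor:
    # allowed: [batch, vocab] bool, True = token allowed.
    # Returns a [batch, ceil(vocab / 32)] int32 tensor; assumes a little-endian host so
    # that viewing the packed bytes as int32 preserves the bit order the kernel expects.
    batch, vocab = allowed.shape
    padded_vocab = (vocab + 31) // 32 * 32
    padded = np.zeros((batch, padded_vocab), dtype=bool)
    padded[:, :vocab] = allowed.cpu().numpy()
    packed = np.packbits(padded, axis=-1, bitorder="little").view(np.int32)
    return torch.from_numpy(packed)


# Allow only tokens 0 and 5 of a toy 40-token vocab; every other logit becomes -inf.
allowed = torch.zeros(1, 40, dtype=torch.bool)
allowed[0, [0, 5]] = True
bitmask = pack_allow_mask(allowed)            # shape [1, 2], dtype int32
logits = torch.randn(1, 40, device="cuda")
apply_token_bitmask_inplace_triton(logits, bitmask.to(logits.device))
assert torch.isinf(logits[0, 1]) and not torch.isinf(logits[0, 0])
```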
@@ -25,13 +25,16 @@ from xgrammar import (
StructuralTagItem,
TokenizerInfo,
allocate_token_bitmask,
apply_token_bitmask_inplace,
)
from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)
from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton,
)
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
self.override_stop_tokens = override_stop_tokens
self.finished = False
        # Fix (from the vLLM team): postpone the import of apply_token_bitmask_inplace_kernels
        # to the class-init site to avoid re-initializing CUDA in a forked subprocess.
from xgrammar.kernels import apply_token_bitmask_inplace_kernels
self.use_token_bitmask_triton = get_bool_env_var(
"SGLANG_TOKEN_BITMASK_TRITON", "false"
)
self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
"cuda", None
)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
def accept_token(self, token: int):
assert self.matcher.accept_token(token)
@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask.to(device, non_blocking=True)
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
apply_token_bitmask_inplace(logits, vocab_mask)
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if (
not self.use_token_bitmask_triton
and logits.device.type == "cuda"
and self.apply_vocab_mask_cuda
):
return self.apply_vocab_mask_cuda(logits, vocab_mask)
if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
return self.apply_vocab_mask_cpu(logits, vocab_mask)
apply_token_bitmask_inplace_triton(logits, vocab_mask)
def copy(self):
matcher = GrammarMatcher(
......
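Two notes on the `apply_vocab_mask` rework above. Dispatch: unless `SGLANG_TOKEN_BITMASK_TRITON` forces the Triton path, the method prefers xgrammar's native CUDA or CPU kernel when one is available and falls back to the Triton implementation otherwise. And the `__init__` comment credits the vLLM team for moving the `apply_token_bitmask_inplace_kernels` import into the constructor; below is a minimal sketch of why that helps with fork-based workers, assuming only that importing the kernels may initialize CUDA (the class and function names are illustrative):

```python
import multiprocessing as mp


class LazyKernels:
    def __init__(self):
        # Imported here, in the child process, rather than at module import time, so a
        # forked worker never inherits an already-initialized CUDA context from its parent.
        from xgrammar.kernels import apply_token_bitmask_inplace_kernels

        self.kernels = apply_token_bitmask_inplace_kernels  # dict: device type -> kernel


def worker():
    print(sorted(LazyKernels().kernels.keys()))  # e.g. ["cpu", "cuda", ...]


if __name__ == "__main__":
    # "fork" is POSIX-only; it is the start method where the lazy import matters most.
    mp.get_context("fork").Process(target=worker).start()
```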
@@ -137,11 +137,6 @@ class ModelRunner:
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_outlines_disk_cache:
from outlines.caching import disable_cache
disable_cache()
# Global vars
global_server_args_dict.update(
{
......
@@ -392,6 +392,10 @@ class ServerArgs:
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
)
        # Set the env var before the grammar backends are initialized
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
"1" if self.disable_outlines_disk_cache else "0"
)
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
......
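The new `ServerArgs.__post_init__` lines publish the CLI flag as an env var so the outlines backend in the first hunk can read it at import time. A small sketch of that round trip; `get_bool_env_var_stub` is a hypothetical stand-in for `sglang.srt.utils.get_bool_env_var`, assumed to treat "1" and "true" (case-insensitively) as truthy:

```python
import os


def get_bool_env_var_stub(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).lower() in ("1", "true")


# ServerArgs.__post_init__ sets the flag before the grammar backends are imported ...
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = "1"
# ... and the outlines jump-forward module later reads it at import time.
assert get_bool_env_var_stub("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true") is True
```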