Unverified Commit 911f3ba6 authored by Yixin Dong's avatar Yixin Dong Committed by GitHub
Browse files

upgrade xgrammar to 0.1.19 (#6129)

parent f6f96b05
......@@ -42,7 +42,7 @@ runtime_common = [
"transformers==4.51.1",
"uvicorn",
"uvloop",
"xgrammar==0.1.17",
"xgrammar==0.1.19",
"blobfile==3.0.0"
]
......
......@@ -18,6 +18,7 @@ import logging
from typing import List, Optional, Tuple, Union
import torch
import xgrammar
from xgrammar import (
CompiledGrammar,
GrammarCompiler,
......@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject):
self.override_stop_tokens = override_stop_tokens
self.finished = False
# Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
# class init site to avoid re-initializing CUDA in forked subprocess.
from xgrammar.kernels import apply_token_bitmask_inplace_kernels
self.use_token_bitmask_triton = get_bool_env_var(
"SGLANG_TOKEN_BITMASK_TRITON", "false"
)
self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
"cuda", None
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
apply_token_bitmask_inplace_cpu,
)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
def accept_token(self, token: int):
assert self.matcher.accept_token(token)
......@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject):
return vocab_mask.to(device, non_blocking=True)
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if (
not self.use_token_bitmask_triton
and logits.device.type == "cuda"
and self.apply_vocab_mask_cuda
):
return self.apply_vocab_mask_cuda(logits, vocab_mask)
if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
return self.apply_vocab_mask_cpu(logits, vocab_mask)
apply_token_bitmask_inplace_triton(logits, vocab_mask)
if logits.device.type == "cuda":
apply_token_bitmask_inplace_triton(logits, vocab_mask)
elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
self.apply_vocab_mask_cpu(logits, vocab_mask)
else:
raise RuntimeError(f"Unsupported device: {logits.device.type}")
def copy(self):
matcher = GrammarMatcher(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment