Unverified Commit 911f3ba6 authored by Yixin Dong's avatar Yixin Dong Committed by GitHub
Browse files

upgrade xgrammar to 0.1.19 (#6129)

parent f6f96b05
...@@ -42,7 +42,7 @@ runtime_common = [ ...@@ -42,7 +42,7 @@ runtime_common = [
"transformers==4.51.1", "transformers==4.51.1",
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.17", "xgrammar==0.1.19",
"blobfile==3.0.0" "blobfile==3.0.0"
] ]
......
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import xgrammar
from xgrammar import ( from xgrammar import (
CompiledGrammar, CompiledGrammar,
GrammarCompiler, GrammarCompiler,
...@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject):
self.override_stop_tokens = override_stop_tokens self.override_stop_tokens = override_stop_tokens
self.finished = False self.finished = False
# Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
# class init site to avoid re-initializing CUDA in forked subprocess. apply_token_bitmask_inplace_cpu,
from xgrammar.kernels import apply_token_bitmask_inplace_kernels
self.use_token_bitmask_triton = get_bool_env_var(
"SGLANG_TOKEN_BITMASK_TRITON", "false"
)
self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
"cuda", None
) )
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
def accept_token(self, token: int): def accept_token(self, token: int):
assert self.matcher.accept_token(token) assert self.matcher.accept_token(token)
...@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject):
return vocab_mask.to(device, non_blocking=True) return vocab_mask.to(device, non_blocking=True)
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if ( if logits.device.type == "cuda":
not self.use_token_bitmask_triton apply_token_bitmask_inplace_triton(logits, vocab_mask)
and logits.device.type == "cuda" elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
and self.apply_vocab_mask_cuda self.apply_vocab_mask_cpu(logits, vocab_mask)
): else:
return self.apply_vocab_mask_cuda(logits, vocab_mask) raise RuntimeError(f"Unsupported device: {logits.device.type}")
if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
return self.apply_vocab_mask_cpu(logits, vocab_mask)
apply_token_bitmask_inplace_triton(logits, vocab_mask)
def copy(self): def copy(self):
matcher = GrammarMatcher( matcher = GrammarMatcher(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment