Commit c57bb199 (unverified), authored by Russell Bryant, committed by GitHub
Browse files

[V1] Resolve failed concurrent structured output requests (#19565)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent dba68f91
...@@ -66,11 +66,15 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, ...@@ -66,11 +66,15 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
else: else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
xgr_torch_compile = LazyLoader(
"xgr_torch_compile", globals(),
"xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1103,7 +1107,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1103,7 +1107,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# so we receive it in that format. # so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask) grammar_bitmask = torch.from_numpy(grammar_bitmask)
xgr.apply_token_bitmask_inplace( # Force use of the torch.compile implementation from xgrammar to work
# around issues with the Triton kernel in concurrent structured output
# scenarios. See PR #19565 and issues #19493, #18376 for details.
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
logits, logits,
grammar_bitmask.to(self.device, non_blocking=True), grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices, indices=out_indices,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment