Unverified Commit 7e6544c7 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Perf] Parallelize fill_bitmask to accelerate high-throughput guided decoding (#21862)


Signed-off-by: default avatarBenjamin Chislett <benjamin.chislett@centml.ai>
parent 8e6c7e87
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import multiprocessing import multiprocessing
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -40,6 +40,17 @@ class StructuredOutputManager: ...@@ -40,6 +40,17 @@ class StructuredOutputManager:
self._grammar_bitmask: Optional[torch.Tensor] = None self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32) self._full_mask = torch.tensor(-1, dtype=torch.int32)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
self.fill_bitmask_parallel_threshold = 128
if self.fill_bitmask_parallel_threshold < max_batch_size:
self.fill_bitmask_parallel_batch_size = 16
# Use:
# - at least 1 CPU
# - at most half the number of CPUs or 8, whichever is less
max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
self.executor_for_fillmask = ThreadPoolExecutor(
max_workers=max_workers)
if not self.vllm_config.model_config.skip_tokenizer_init: if not self.vllm_config.model_config.skip_tokenizer_init:
# The default max_workers if not specified is the number of # The default max_workers if not specified is the number of
# CPUs * 5, which is way too high since these tasks are CPU-bound, # CPUs * 5, which is way too high since these tasks are CPU-bound,
...@@ -120,6 +131,26 @@ class StructuredOutputManager: ...@@ -120,6 +131,26 @@ class StructuredOutputManager:
assert self.backend is not None assert self.backend is not None
return self.backend.compile_grammar(request_type, grammar_spec) return self.backend.compile_grammar(request_type, grammar_spec)
def _fill_bitmasks(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
if apply_bitmask and not grammar.is_terminated():
grammar.fill_bitmask(self._grammar_bitmask, index)
else:
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for consequent
# requests here.
self._grammar_bitmask[index].fill_(self._full_mask)
def _async_submit_fill_bitmask(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
def grammar_bitmask( def grammar_bitmask(
self, self,
requests: dict[str, Request], requests: dict[str, Request],
...@@ -146,7 +177,6 @@ class StructuredOutputManager: ...@@ -146,7 +177,6 @@ class StructuredOutputManager:
self.backend.allocate_token_bitmask( self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens)) max_batch_size * (1 + max_num_spec_tokens))
bitmask_tensor = self._grammar_bitmask
# Generate a batched bitmask for all structured output requests. # Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple # When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position. # masks for each request, one for each possible bonus token position.
...@@ -155,14 +185,35 @@ class StructuredOutputManager: ...@@ -155,14 +185,35 @@ class StructuredOutputManager:
ordered_seq = sorted(structured_output_request_ids.items(), ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1]) key=lambda x: x[1])
# Note that for thinking support, we will need to # Optimized parallel filling of bitmasks for
# reset the relevant part of the bitmask for consequent # non-spec, large-batch-size cases
# request here. if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( max_num_spec_tokens == 0:
self._full_mask) promises = []
batch = []
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask = self.should_fill_bitmask(request)
batch.append((structured_output_request.grammar,
cumulative_index, apply_bitmask))
if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch))
batch = []
cumulative_index += 1
if batch:
promises.append(self._async_submit_fill_bitmask(batch))
# NOTE: This outer loop can likely be parallelized to improve # Wait for all bitmask filling tasks to complete.
# performance of bitmask generation for large batches. for promise in promises:
promise.result()
else:
# Fallback to serial filling of bitmasks for small-batch-size cases
for req_id, _ in ordered_seq: for req_id, _ in ordered_seq:
request = requests[req_id] request = requests[req_id]
structured_output_request = request.structured_output_request structured_output_request = request.structured_output_request
...@@ -170,32 +221,25 @@ class StructuredOutputManager: ...@@ -170,32 +221,25 @@ class StructuredOutputManager:
if TYPE_CHECKING: if TYPE_CHECKING:
assert structured_output_request is not None assert structured_output_request is not None
assert structured_output_request.grammar is not None assert structured_output_request.grammar is not None
apply_bitmask: bool = True apply_bitmask = self.should_fill_bitmask(request)
if self.reasoner is not None:
if structured_output_request.reasoning_ended is None:
structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
apply_bitmask = structured_output_request.reasoning_ended
state_advancements = 0 state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
for i, token in enumerate(req_tokens): for i, token in enumerate(req_tokens + [None]):
if apply_bitmask and not \ self._fill_bitmasks([(structured_output_request.grammar,
structured_output_request.grammar.is_terminated(): cumulative_index, apply_bitmask)])
structured_output_request.grammar.fill_bitmask(
bitmask_tensor, cumulative_index) if apply_bitmask and token is not None and \
if token is not None: not structured_output_request.grammar.is_terminated():
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
assert structured_output_request.grammar.accept_tokens( assert structured_output_request.grammar.accept_tokens(
req_id, [token]) req_id, [token])
state_advancements += 1 state_advancements += 1
cumulative_index += 1 cumulative_index += 1
if state_advancements > 0: if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements) structured_output_request.grammar.rollback(
state_advancements)
bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]: if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index] bitmask_tensor = bitmask_tensor[:cumulative_index]
...@@ -204,6 +248,15 @@ class StructuredOutputManager: ...@@ -204,6 +248,15 @@ class StructuredOutputManager:
# and deserialization when sending this to the GPU workers. # and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy() return bitmask_tensor.numpy()
def should_fill_bitmask(self, request: Request) -> bool:
if self.reasoner is not None:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
return request.structured_output_request.reasoning_ended
return True
def should_advance(self, request: Request) -> bool: def should_advance(self, request: Request) -> bool:
if not request.use_structured_output: if not request.use_structured_output:
return False return False
......
...@@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar): ...@@ -148,6 +148,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
repr=False, repr=False,
hash=False, hash=False,
init=False) init=False)
_is_terminated: bool = field(default=False, repr=False, hash=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM. """Accepts a list of tokens and advances the FSM.
...@@ -155,6 +156,8 @@ class XgrammarGrammar(StructuredOutputGrammar): ...@@ -155,6 +156,8 @@ class XgrammarGrammar(StructuredOutputGrammar):
Returns True if the FSM was advanced successfully. Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance. Returns False if the FSM failed to advance.
""" """
if self._is_terminated:
return False
for token in tokens: for token in tokens:
if not self.matcher.accept_token(token): if not self.matcher.accept_token(token):
logger.error( logger.error(
...@@ -162,6 +165,7 @@ class XgrammarGrammar(StructuredOutputGrammar): ...@@ -162,6 +165,7 @@ class XgrammarGrammar(StructuredOutputGrammar):
"for tokens %s. Please file an issue.", request_id, token) "for tokens %s. Please file an issue.", request_id, token)
return False return False
self.num_processed_tokens += 1 self.num_processed_tokens += 1
self._is_terminated = self.matcher.is_terminated()
return True return True
def validate_tokens(self, tokens: list[int]) -> list[int]: def validate_tokens(self, tokens: list[int]) -> list[int]:
...@@ -184,12 +188,13 @@ class XgrammarGrammar(StructuredOutputGrammar): ...@@ -184,12 +188,13 @@ class XgrammarGrammar(StructuredOutputGrammar):
def rollback(self, num_tokens: int) -> None: def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens) self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens self.num_processed_tokens -= num_tokens
self._is_terminated = self.matcher.is_terminated()
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx) self.matcher.fill_next_token_bitmask(bitmask, idx)
def is_terminated(self) -> bool: def is_terminated(self) -> bool:
return self.matcher.is_terminated() return self._is_terminated
def reset(self): def reset(self):
self.num_processed_tokens = 0 self.num_processed_tokens = 0
......
...@@ -1324,9 +1324,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1324,9 +1324,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cumulative_index += 1 + num_spec_tokens cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask grammar_bitmask = sorted_bitmask
# If the grammar bitmask and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor, # Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format. # so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask) grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
# Force use of the torch.compile implementation from xgrammar to work # Force use of the torch.compile implementation from xgrammar to work
# around issues with the Triton kernel in concurrent structured output # around issues with the Triton kernel in concurrent structured output
...@@ -1334,7 +1339,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1334,7 +1339,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile( xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
logits, logits,
grammar_bitmask.to(self.device, non_blocking=True), grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices, indices=out_indices if not skip_out_indices else None,
) )
def sync_and_slice_intermediate_tensors( def sync_and_slice_intermediate_tensors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment