Unverified Commit 56539cdd authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)

parent 430dd4d9
......@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
```python
class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
num_reqs: int
uniform: bool = False
has_lora: bool = False
```
where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
where `num_tokens` can be the padded token length, and `uniform` indicates if all the requests have the same query lengths. Many attention backends only support full cudagraphs when the batches are uniform; pure decode batches are uniform but may not be query length 1 (i.e. `num_tokens == num_reqs`), this occurs in the validation pass of spec-decode where "decode" batches will have a query length of `1+num_spec_tokens`.
The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item.
!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
......
......@@ -42,12 +42,24 @@ def _create_vllm_config(
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
# mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes:
compilation_config.max_cudagraph_capture_size = (
compilation_config.cudagraph_capture_sizes[-1]
)
compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)
return mock_config
......@@ -109,9 +121,11 @@ class TestCudagraphDispatcher:
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform_decode=False,
uniform=False,
)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
......@@ -122,32 +136,37 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_no_match)
rt_mode, key = dispatcher.dispatch(
num_tokens=15, uniform_decode=False, has_lora=False
)
assert rt_mode == CUDAGraphMode.NONE
assert key is None
assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.non_uniform
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
......
......@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
"""
num_tokens: int
uniform_decode: bool = False
num_reqs: int | None = None
"""
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
the cudagraphs can handle any number of requests.
"""
uniform: bool = False
"""
True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)
......
......@@ -930,30 +930,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)
else:
num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper(
......
......@@ -107,6 +107,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
return (
block_idx_last_computed_token,
......
......@@ -72,6 +72,7 @@ class CommonAttentionMetadata:
num_reqs: int
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
......@@ -857,7 +858,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
......
......@@ -4,6 +4,9 @@ from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
......@@ -28,7 +31,11 @@ class CudagraphDispatcher:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.uniform_decode_query_len = (
1
if not self.vllm_config.speculative_config
else 1 + self.vllm_config.speculative_config.num_speculative_tokens
)
# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
......@@ -36,25 +43,42 @@ class CudagraphDispatcher:
CUDAGraphMode.FULL: set(),
}
not_use_piecewise_compilation = (
not self.cudagraph_mode.requires_piecewise_compilation()
)
assert (
not_use_piecewise_compilation
not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
or self.compilation_config.is_attention_compiled_piecewise()
), (
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
"cudagraph_mode piecewise cudagraphs is used, "
"and attention should be in splitting_ops or "
"inductor splitting should be used. "
f"cudagraph_mode={self.cudagraph_mode}, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
f"compilation_mode={self.compilation_config.mode}, "
f"splitting_ops={self.compilation_config.splitting_ops}"
)
self.keys_initialized = False
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False
num_reqs = min(num_tokens_padded, max_num_seqs)
return BatchDescriptor(
num_tokens=num_tokens_padded,
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
)
def add_cudagraph_key(
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
......@@ -66,7 +90,9 @@ class CudagraphDispatcher:
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
):
# This should be called only after attention backend is initialized.
# This should be called only after attention backend is initialized. So we can
# get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
......@@ -86,9 +112,9 @@ class CudagraphDispatcher:
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
self._create_padded_batch_descriptor(
bs, False, has_lora
).relax_for_mixed_batch_cudagraphs(),
)
# if decode cudagraph mode is FULL, and we don't already have mixed
......@@ -109,40 +135,49 @@ class CudagraphDispatcher:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)
self.keys_initialized = True
def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
return CUDAGraphMode.NONE, None
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
non_uniform_key = batch_descriptor.non_uniform
# if a batch use cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc
# also check if non-uniform key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......@@ -9,6 +9,7 @@ from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
......@@ -88,6 +89,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()
# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
padded_second_ubatch_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
ubatch_slices[1] = UBatchSlice(
padded_second_ubatch_slice, padded_second_ubatch_slice
)
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
......@@ -220,11 +232,14 @@ def coordinate_batch_across_dp(
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2
num_tokens_padded = int(num_tokens_after_padding[0].item())
token_split_point = int(num_tokens_padded) // 2
assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
return (ubatch_slices, num_tokens_after_padding)
This diff is collapsed.
......@@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
......@@ -487,6 +488,7 @@ class Worker(WorkerBase):
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
......@@ -534,12 +536,39 @@ class Worker(WorkerBase):
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
all_gather_tensors = {}
compilation_config = self.vllm_config.compilation_config
parallel_config = self.vllm_config.parallel_config
if (
parallel_config.pipeline_parallel_size > 1
and compilation_config.pass_config.enable_sequence_parallelism
and forward_pass
):
# currently only supported by V1 GPUModelRunner
assert isinstance(self.model_runner, GPUModelRunner)
num_scheduled_tokens_np = np.array(
list(scheduler_output.num_scheduled_tokens.values()),
dtype=np.int32,
)
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
use_cascade_attn=False, # TODO(lucas): Handle cascade attention
)
)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens
self.vllm_config, batch_desc.num_tokens
)
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment