Unverified commit 1d532f9d, authored by Lucas Wilkinson, committed by GitHub
Browse files

[DP] Only use DP padding when cudagraphs are actually used (#34102)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 234a65b7
......@@ -176,10 +176,14 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)
# 4. disable_full should have a fall back mode (e.g., cascade attention)
# 4. invalid_modes={FULL} should have a fall back mode
# (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
num_tokens=8,
uniform_decode=False,
has_lora=False,
invalid_modes={CUDAGraphMode.FULL},
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
......@@ -188,6 +192,16 @@ class TestCudagraphDispatcher:
else:
assert rt_mode == CUDAGraphMode.NONE
# 5. valid_modes={NONE} always returns NONE even when keys exist
rt_mode, key = dispatcher.dispatch(
num_tokens=8,
uniform_decode=False,
has_lora=False,
valid_modes={CUDAGraphMode.NONE},
)
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=8)
@pytest.mark.parametrize(
"cudagraph_mode_str,compilation_mode,expected_modes",
[
......
......@@ -87,8 +87,12 @@ class CUDAGraphMode(enum.Enum):
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
def valid_runtime_modes(self) -> bool:
return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
@classmethod
def valid_runtime_modes(cls) -> frozenset["CUDAGraphMode"]:
return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL})
def is_valid_runtime_mode(self) -> bool:
return self in CUDAGraphMode.valid_runtime_modes()
def __str__(self) -> str:
return self.name
......
......@@ -241,7 +241,7 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
assert self.cudagraph_runtime_mode.is_valid_runtime_mode(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
)
......@@ -347,7 +347,6 @@ def set_forward_context(
num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=False,
)
assert num_tokens_across_dp is not None
dp_metadata = DPMetadata.make(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product
......@@ -232,8 +233,9 @@ class CudagraphDispatcher:
num_tokens: int,
uniform_decode: bool = False,
has_lora: bool = False,
disable_full: bool = False,
num_active_loras: int = 0,
valid_modes: AbstractSet[CUDAGraphMode] | None = None,
invalid_modes: AbstractSet[CUDAGraphMode] | None = None,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using piecewise only),
......@@ -246,15 +248,29 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
num_active_loras: Number of distinct active LoRA adapters.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
invalid_modes: Set of cudagraph modes to exclude. Subtracted from
valid_modes to compute allowed modes. (e.g., {FULL} for
features like cascade attention not supported by full
cudagraphs). None means no modes are excluded.
"""
allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if invalid_modes:
allowed_modes -= invalid_modes
assert len(allowed_modes) >= 1, (
f"No allowed cudagraph modes: valid_modes={valid_modes}, "
f"invalid_modes={invalid_modes}"
)
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
or allowed_modes <= {CUDAGraphMode.NONE}
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......@@ -281,24 +297,26 @@ class CudagraphDispatcher:
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
if CUDAGraphMode.FULL in allowed_modes:
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc_to_check
if CUDAGraphMode.PIECEWISE in allowed_modes:
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
# finally, just return no cudagraphs and a trivial batch descriptor
assert CUDAGraphMode.NONE in allowed_modes, (
f"No matching cudagraph found and NONE is not in "
f"allowed_modes={allowed_modes}"
)
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
......
......@@ -448,17 +448,10 @@ class SpecDecodeBaseProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
......@@ -549,17 +542,10 @@ class SpecDecodeBaseProposer:
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(batch_size)
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
......@@ -1568,19 +1554,11 @@ class SpecDecodeBaseProposer:
self.num_speculative_tokens if not is_graph_capturing else 1
):
if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
if use_cudagraphs:
cudagraph_runtime_mode, batch_desc = (
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
)
num_input_tokens = batch_desc.num_tokens
else:
cudagraph_runtime_mode = CUDAGraphMode.NONE
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
......@@ -1680,28 +1658,49 @@ class SpecDecodeBaseProposer:
== 1
), "All drafting layers should belong to the same kv cache group"
def _pad_batch_across_dp(
def _determine_batch_execution_and_padding(
self,
num_tokens_unpadded: int,
num_tokens_padded: int,
) -> tuple[int, torch.Tensor]:
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
!= CUDAGraphMode.NONE,
num_tokens_padded=num_tokens_padded,
uniform_decode=None,
num_scheduled_tokens_per_request=None,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
num_tokens_dp_padded = num_tokens_padded
if num_toks_across_dp is not None:
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
return num_tokens_dp_padded, num_toks_across_dp
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
class EagleProposer(SpecDecodeBaseProposer):
......
......@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
......@@ -46,12 +45,11 @@ def _run_ar(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
tensor[3][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
......@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
return int(tensor[3, :].min().item())
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
......@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens
When running microbatched or if cudagraph is enabled (synced across ranks),
all ranks will be padded out so that they run with the same number of tokens.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
......@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Synchronize cudagraph_mode across ranks first (take min).
# This is needed before DP padding decision since we use the synced
# cudagraph mode to determine whether DP padding is needed.
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# DP padding is needed when cudagraph is enabled (synced across ranks)
# or when ubatching/DBO is active (ubatching requires uniform batch
# sizes across DP ranks currently).
# Use the synced runtime cudagraph mode rather than the compilation config
# so we can avoid padding when cudagraph is not enabled for this step.
should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
......@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
should_dp_pad,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp(
num_tokens_unpadded: int,
allow_microbatching: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
......@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
......@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
DP padding is enabled when synced cudagraph mode across ranks is not NONE.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
padded up to the max value across all DP ranks when cudagraph is enabled.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
......@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
......
......@@ -2300,7 +2300,7 @@ class GPUModelRunner(
)
# Dispatch for the decoder portion of the model.
_, batch_desc = self.cudagraph_dispatcher.dispatch(
num_logits, disable_full=True
num_logits, invalid_modes={CUDAGraphMode.FULL}
)
num_logits_padded = batch_desc.num_tokens
logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
......@@ -3174,20 +3174,19 @@ class GPUModelRunner(
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = (
lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None):
return self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens,
has_lora=has_lora,
uniform_decode=uniform_decode,
disable_full=disable_full,
num_active_loras=num_active_loras,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
valid_modes={CUDAGraphMode.NONE} if force_eager else valid_modes,
invalid_modes={CUDAGraphMode.FULL} if disable_full else None,
)
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded, use_cascade_attn or has_encoder_output
num_tokens_padded, disable_full=use_cascade_attn or has_encoder_output
)
num_tokens_padded = batch_descriptor.num_tokens
if self.compilation_config.pass_config.enable_sp:
......@@ -3204,20 +3203,11 @@ class GPUModelRunner(
# across ranks
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
# in a P/D setup and still use CUDA graphs (enabled by this padding) on the
# decoder.
allow_dp_padding = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
......@@ -3232,7 +3222,7 @@ class GPUModelRunner(
# Re-dispatch with DP padding so we have the correct batch_descriptor
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded,
disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct otherwise
# num_tokens_across_dp will no-longer be valid
......@@ -4724,7 +4714,7 @@ class GPUModelRunner(
assert (
cudagraph_runtime_mode is None
or cudagraph_runtime_mode.valid_runtime_modes()
or cudagraph_runtime_mode.is_valid_runtime_mode()
)
# If cudagraph_mode.decode_mode() == FULL and
......@@ -5336,7 +5326,7 @@ class GPUModelRunner(
):
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
and cudagraph_runtime_mode.is_valid_runtime_mode()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
if not batch_descriptors:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment