Unverified commit 22613408, authored by Lucas Wilkinson and committed by GitHub
Browse files

[Misc] Remove pad_for_cudagraphs from config (#30143)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
parent 86c69dc5
...@@ -2,14 +2,20 @@ ...@@ -2,14 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
from contextlib import nullcontext from contextlib import nullcontext
from unittest.mock import patch from unittest.mock import MagicMock, patch
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import (
CompilationConfig,
CUDAGraphMode,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.config.compilation import CompilationMode, PassConfig from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -17,6 +23,7 @@ from vllm.utils.torch_utils import ( ...@@ -17,6 +23,7 @@ from vllm.utils.torch_utils import (
_is_torch_equal_or_newer, _is_torch_equal_or_newer,
is_torch_equal, is_torch_equal,
) )
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401 from . import silly_attention # noqa: F401
...@@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config): ...@@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config):
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
def _create_vllm_config_for_validation(
    compilation_config: CompilationConfig,
) -> MagicMock:
    """Build a ``VllmConfig``-shaped mock for the padding-validation tests.

    Args:
        compilation_config: The compilation config under test; it is attached
            to the mock unchanged.

    Returns:
        A ``MagicMock`` created with ``spec=VllmConfig`` carrying the given
        compilation config, a scheduler config with ``max_num_seqs=8``, a
        default ``ParallelConfig``, and no speculative or LoRA config.
    """
    vllm_config = MagicMock(spec=VllmConfig)
    # configure_mock assigns each keyword as an attribute on the mock.
    vllm_config.configure_mock(
        compilation_config=compilation_config,
        scheduler_config=SchedulerConfig.default_factory(max_num_seqs=8),
        parallel_config=ParallelConfig(),
        speculative_config=None,
        lora_config=None,
    )
    return vllm_config
def test_compile_sizes_padding_validation(): def test_compile_sizes_padding_validation():
"""Test that compile_sizes with values that would be padded raises an error.""" """Test that compile_sizes with values that would be padded raises an error."""
# cudagraph_capture_sizes=[1, 2, 4, 8] means: # cudagraph_capture_sizes=[1, 2, 4, 8] means:
...@@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation(): ...@@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation():
cudagraph_capture_sizes=[1, 2, 4, 8], cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8, max_cudagraph_capture_size=8,
compile_sizes=[3], compile_sizes=[3],
cudagraph_mode=CUDAGraphMode.FULL,
) )
config.post_init_cudagraph_sizes() config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
with pytest.raises(ValueError, match="would be padded to"): with pytest.raises(ValueError, match="would be padded to"):
config = CompilationConfig( config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8], cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8, max_cudagraph_capture_size=8,
compile_sizes=[5], compile_sizes=[5],
cudagraph_mode=CUDAGraphMode.FULL,
) )
config.post_init_cudagraph_sizes() config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
config = CompilationConfig( config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8], cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8, max_cudagraph_capture_size=8,
compile_sizes=[1, 2, 4, 8], compile_sizes=[1, 2, 4, 8],
cudagraph_mode=CUDAGraphMode.FULL,
) )
config.post_init_cudagraph_sizes() config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8] assert sorted(config.compile_sizes) == [1, 2, 4, 8]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise
config = CompilationConfig( config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8], cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8, max_cudagraph_capture_size=8,
compile_sizes=["cudagraph_capture_sizes"], compile_sizes=["cudagraph_capture_sizes"],
cudagraph_mode=CUDAGraphMode.FULL,
) )
config.post_init_cudagraph_sizes() config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8] assert sorted(config.compile_sizes) == [1, 2, 4, 8]
...@@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation(): ...@@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation():
) )
config.post_init_cudagraph_sizes() config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [3, 5, 7] assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
...@@ -9,6 +9,7 @@ from tests.models.registry import HF_EXAMPLE_MODELS ...@@ -9,6 +9,7 @@ from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from ...utils import check_logprobs_close, check_outputs_equal from ...utils import check_logprobs_close, check_outputs_equal
...@@ -172,7 +173,14 @@ def test_mamba_cache_cg_padding( ...@@ -172,7 +173,14 @@ def test_mamba_cache_cg_padding(
tensor dimensions aren't compatible. tensor dimensions aren't compatible.
""" """
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(
vllm_config.compilation_config.cudagraph_mode
)
while (
len(example_prompts)
== cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
):
example_prompts.append(example_prompts[0]) example_prompts.append(example_prompts[0])
try: try:
......
...@@ -61,9 +61,6 @@ def _create_vllm_config( ...@@ -61,9 +61,6 @@ def _create_vllm_config(
) )
compilation_config.post_init_cudagraph_sizes() compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)
return mock_config return mock_config
...@@ -169,6 +166,7 @@ class TestCudagraphDispatcher: ...@@ -169,6 +166,7 @@ class TestCudagraphDispatcher:
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
) )
if "PIECEWISE" in cudagraph_mode_str: # string contains check if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
...@@ -360,7 +358,7 @@ class TestCudagraphIntegration: ...@@ -360,7 +358,7 @@ class TestCudagraphIntegration:
): ):
full_wrapper(input_1) full_wrapper(input_1)
rt_mode, key = self.dispatcher.dispatch(desc_1) rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
# 1. Capture first shape # 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "capture_global" assert action == "capture_global"
...@@ -369,7 +367,7 @@ class TestCudagraphIntegration: ...@@ -369,7 +367,7 @@ class TestCudagraphIntegration:
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "replay" assert action == "replay"
rt_mode, key = self.dispatcher.dispatch(desc_2) rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
# 3. Capture second shape # 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key) action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
assert action == "capture_global" assert action == "capture_global"
...@@ -381,7 +379,7 @@ class TestCudagraphIntegration: ...@@ -381,7 +379,7 @@ class TestCudagraphIntegration:
assert action == "replay" assert action == "replay"
# 5. Bypass if no key match # 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key) action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
assert action == "bypass" assert action == "bypass"
......
...@@ -15,7 +15,7 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner ...@@ -15,7 +15,7 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import Range from vllm.config.utils import Range
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -11,7 +11,7 @@ import torch.fx as fx ...@@ -11,7 +11,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import Range from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
......
...@@ -581,15 +581,6 @@ class CompilationConfig: ...@@ -581,15 +581,6 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank""" """local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops # keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled""" """custom ops that are enabled"""
...@@ -639,7 +630,6 @@ class CompilationConfig: ...@@ -639,7 +630,6 @@ class CompilationConfig:
"debug_dump_path", "debug_dump_path",
"cache_dir", "cache_dir",
"local_cache_dir", "local_cache_dir",
"bs_to_padded_graph_size",
"traced_files", "traced_files",
"compilation_time", "compilation_time",
"static_forward_context", "static_forward_context",
...@@ -661,7 +651,6 @@ class CompilationConfig: ...@@ -661,7 +651,6 @@ class CompilationConfig:
"enabled_custom_ops": True, "enabled_custom_ops": True,
"disabled_custom_ops": True, "disabled_custom_ops": True,
"compilation_time": True, "compilation_time": True,
"bs_to_padded_graph_size": True,
"traced_files": True, "traced_files": True,
"inductor_compile_config": { "inductor_compile_config": {
"post_grad_custom_post_pass": True, "post_grad_custom_post_pass": True,
...@@ -882,7 +871,6 @@ class CompilationConfig: ...@@ -882,7 +871,6 @@ class CompilationConfig:
"""To complete the initialization after cudagraph related """To complete the initialization after cudagraph related
configs are set. This includes: configs are set. This includes:
- initialize compile_sizes - initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
""" """
computed_compile_sizes = [] computed_compile_sizes = []
...@@ -906,23 +894,6 @@ class CompilationConfig: ...@@ -906,23 +894,6 @@ class CompilationConfig:
if self.cudagraph_capture_sizes: if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()
# Validate that compile_sizes won't be changed by padding.
# Only validate when cudagraphs are actually being used.
if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE:
for size in self.compile_sizes:
if size <= self.max_cudagraph_capture_size:
padded = self.bs_to_padded_graph_size[size]
if padded != size:
raise ValueError(
f"compile_sizes contains {size} which would be "
f"padded to {padded}. All compile_sizes must be "
"values that won't be changed by cudagraph padding. "
"Use values from cudagraph_capture_sizes."
)
def set_splitting_ops_for_v1( def set_splitting_ops_for_v1(
self, all2all_backend: str, data_parallel_size: int = 1 self, all2all_backend: str, data_parallel_size: int = 1
): ):
...@@ -1134,24 +1105,6 @@ class CompilationConfig: ...@@ -1134,24 +1105,6 @@ class CompilationConfig:
self.max_cudagraph_capture_size = rounded_sizes[-1] self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes self.cudagraph_capture_sizes = rounded_sizes
# Recompute after adjusting the cudagraph sizes
self.compute_bs_to_padded_graph_size()
def compute_bs_to_padded_graph_size(self):
    """Rebuild ``self.bs_to_padded_graph_size`` from the capture sizes.

    The table has ``max_cudagraph_capture_size + 1`` entries and is indexed
    by batch size. Index 0 and every value listed in
    ``cudagraph_capture_sizes`` map to themselves; any batch size strictly
    between two consecutive boundaries maps to the upper boundary.

    NOTE(review): this walks consecutive pairs of ``cudagraph_capture_sizes``,
    so it presumably expects the list in ascending order -- confirm against
    the code that populates it.
    """
    largest = self.max_cudagraph_capture_size
    # Publish the table first, then fill it in place through a local alias.
    table = self.bs_to_padded_graph_size = [0] * (largest + 1)

    # Consecutive (lower, upper) boundaries: 0 is prepended as the first
    # lower bound and ``largest + 1`` appended as the last upper bound.
    upper_bounds = self.cudagraph_capture_sizes + [largest + 1]
    lower_bounds = [0] + self.cudagraph_capture_sizes
    for lower, upper in zip(lower_bounds, upper_bounds):
        for batch_size in range(lower, upper):
            # An exact boundary keeps its size; everything else rounds up.
            table[batch_size] = lower if batch_size == lower else upper
def get_compile_ranges(self) -> list[Range]: def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config.""" """Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None: if self.compile_ranges_split_points is None:
......
...@@ -57,13 +57,47 @@ class CudagraphDispatcher: ...@@ -57,13 +57,47 @@ class CudagraphDispatcher:
) )
self.keys_initialized = False self.keys_initialized = False
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE
def _compute_bs_to_padded_graph_size(self) -> None:
    """Build the batch-size -> padded-size table and validate compile_sizes.

    The table is stored on ``self._bs_to_padded_graph_size`` and has
    ``max_cudagraph_capture_size + 1`` entries, indexed by batch size.
    Index 0 and every value in ``cudagraph_capture_sizes`` map to themselves;
    any batch size strictly between two consecutive boundaries maps to the
    upper boundary.

    NOTE(review): the pairing of consecutive capture sizes presumably relies
    on ``cudagraph_capture_sizes`` being in ascending order -- confirm against
    the config code that fills it.

    Raises:
        ValueError: If cudagraphs are in use (``self.cudagraph_mode`` is not
            ``CUDAGraphMode.NONE``) and some entry of ``compile_sizes`` that
            is within the capture range would be padded to a different size.
    """
    config = self.compilation_config
    largest = config.max_cudagraph_capture_size
    sizes = config.cudagraph_capture_sizes

    # Publish the table first, then fill it in place through a local alias.
    self._bs_to_padded_graph_size: list[int] = [0] * (largest + 1)
    table = self._bs_to_padded_graph_size

    # Consecutive (lower, upper) boundaries: 0 is prepended as the first
    # lower bound and ``largest + 1`` appended as the last upper bound.
    upper_bounds = sizes + [largest + 1]
    lower_bounds = [0] + sizes
    for lower, upper in zip(lower_bounds, upper_bounds):
        for batch_size in range(lower, upper):
            # An exact boundary keeps its size; everything else rounds up.
            table[batch_size] = lower if batch_size == lower else upper

    # Validate that compile_sizes won't be changed by padding.
    # Only validate when cudagraphs are actually being used.
    if not config.compile_sizes or self.cudagraph_mode == CUDAGraphMode.NONE:
        return

    for size in config.compile_sizes:
        if size > config.max_cudagraph_capture_size:
            # Sizes beyond the capture range are never looked up in the table.
            continue
        padded = table[size]
        if padded != size:
            raise ValueError(
                f"compile_sizes contains {size} which would be "
                f"padded to {padded}. All compile_sizes must be "
                "values that won't be changed by cudagraph padding. "
                "Use values from cudagraph_capture_sizes."
            )
def _create_padded_batch_descriptor( def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor: ) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens) num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len num_reqs = num_tokens_padded // uniform_decode_query_len
...@@ -88,12 +122,19 @@ class CudagraphDispatcher: ...@@ -88,12 +122,19 @@ class CudagraphDispatcher:
self.cudagraph_keys[runtime_mode].add(batch_descriptor) self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def initialize_cudagraph_keys( def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1
): ):
# This should be called only after attention backend is initialized. So we can # This should be called only after attention backend is initialized. So we can
# get the correct cudagraph mode after backend support is resolved. # get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode self.cudagraph_mode = cudagraph_mode
# Early exit if cudagraphs are disabled
if cudagraph_mode == CUDAGraphMode.NONE:
self.keys_initialized = True
return
self._compute_bs_to_padded_graph_size()
# LoRA activation cases to specialize the cuda graphs on # LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config: if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora: if self.compilation_config.cudagraph_specialize_lora:
...@@ -143,15 +184,24 @@ class CudagraphDispatcher: ...@@ -143,15 +184,24 @@ class CudagraphDispatcher:
def dispatch( def dispatch(
self, self,
num_tokens: int, num_tokens: int,
uniform_decode: bool, uniform_decode: bool = False,
has_lora: bool, has_lora: bool = False,
disable_full: bool = False, disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]: ) -> tuple[CUDAGraphMode, BatchDescriptor]:
""" """
Given conditions(e.g.,batch descriptor and if using cascade attention), Given conditions(e.g.,batch descriptor and if using piecewise only),
dispatch to a cudagraph runtime mode and the valid batch descriptor. dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform). to a graph that supports a more general batch (uniform to non-uniform).
Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
""" """
if ( if (
not self.keys_initialized not self.keys_initialized
......
...@@ -9,7 +9,6 @@ import torch ...@@ -9,7 +9,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ( from vllm.config import (
CompilationMode,
CUDAGraphMode, CUDAGraphMode,
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
...@@ -36,6 +35,7 @@ from vllm.v1.attention.backends.tree_attn import ( ...@@ -36,6 +35,7 @@ from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadataBuilder, TreeAttentionMetadataBuilder,
) )
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
...@@ -100,24 +100,13 @@ class SpecDecodeBaseProposer: ...@@ -100,24 +100,13 @@ class SpecDecodeBaseProposer:
self._get_eagle3_use_aux_hidden_state_from_config() self._get_eagle3_use_aux_hidden_state_from_config()
) )
self.use_cuda_graph = False
self.compilation_config = self.vllm_config.compilation_config self.compilation_config = self.vllm_config.compilation_config
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
cudagraph_mode = self.compilation_config.cudagraph_mode # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( # Keys are initialized later via initialize_cudagraph_keys() called from
CUDAGraphMode.PIECEWISE # gpu_model_runner._check_and_update_cudagraph_mode after
): # adjust_cudagraph_sizes_for_spec_decode is called.
logger.warning( self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
"Currently the eagle proposer only supports cudagraph_mode "
"PIECEWISE, if you want the drafter to use cuda graphs, "
"please set compilation_config.cudagraph_mode to PIECEWISE "
"or FULL_AND_PIECEWISE"
)
self.use_cuda_graph = (
cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
# persistent buffers for cuda graph # persistent buffers for cuda graph
self.input_ids = torch.zeros( self.input_ids = torch.zeros(
...@@ -234,6 +223,23 @@ class SpecDecodeBaseProposer: ...@@ -234,6 +223,23 @@ class SpecDecodeBaseProposer:
else: else:
self.positions[:num_tokens] = positions self.positions[:num_tokens] = positions
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Set up the drafter's cudagraph dispatcher keys.

    Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
    This should be called after adjust_cudagraph_sizes_for_spec_decode.

    Args:
        cudagraph_mode: The resolved cudagraph mode. The dispatcher is
            initialized with ``CUDAGraphMode.PIECEWISE`` when eager execution
            is not enforced for speculative decoding and
            ``cudagraph_mode.mixed_mode()`` is PIECEWISE or FULL; otherwise
            it is initialized with ``CUDAGraphMode.NONE``.
    """
    # Short-circuit: mixed_mode() is only consulted when eager is not enforced.
    use_piecewise = (
        not self.speculative_config.enforce_eager
        and cudagraph_mode.mixed_mode()
        in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
    )
    eagle_cudagraph_mode = (
        CUDAGraphMode.PIECEWISE if use_piecewise else CUDAGraphMode.NONE
    )
    self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -304,16 +310,10 @@ class SpecDecodeBaseProposer: ...@@ -304,16 +310,10 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
) )
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
if ( num_tokens_dp_padded
self.use_cuda_graph )
and num_tokens_dp_padded num_input_tokens = batch_desc.num_tokens
<= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
...@@ -412,16 +412,10 @@ class SpecDecodeBaseProposer: ...@@ -412,16 +412,10 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
) )
if ( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
self.use_cuda_graph batch_size_dp_padded
and batch_size_dp_padded )
<= self.compilation_config.max_cudagraph_capture_size input_batch_size = batch_desc.num_tokens
):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
if batch_size_across_dp is not None: if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size batch_size_across_dp[self.dp_rank] = input_batch_size
...@@ -870,15 +864,10 @@ class SpecDecodeBaseProposer: ...@@ -870,15 +864,10 @@ class SpecDecodeBaseProposer:
self.positions[:num_tokens] = tree_positions.view(-1) self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
if ( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
self.use_cuda_graph num_tokens
and num_tokens <= self.compilation_config.max_cudagraph_capture_size )
): num_input_tokens = batch_desc.num_tokens
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens
cudagraph_runtime_mode = CUDAGraphMode.NONE
# Run the model. # Run the model.
with set_forward_context( with set_forward_context(
per_layer_attn_metadata, per_layer_attn_metadata,
...@@ -1216,9 +1205,6 @@ class SpecDecodeBaseProposer: ...@@ -1216,9 +1205,6 @@ class SpecDecodeBaseProposer:
use_cudagraphs: bool = True, use_cudagraphs: bool = True,
is_graph_capturing: bool = False, is_graph_capturing: bool = False,
) -> None: ) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
# FIXME: when using tree-based specdec, adjust number of forward-passes # FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree. # according to the depth of the tree.
for fwd_idx in range( for fwd_idx in range(
...@@ -1228,16 +1214,10 @@ class SpecDecodeBaseProposer: ...@@ -1228,16 +1214,10 @@ class SpecDecodeBaseProposer:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
) )
if ( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
cudagraphs_enabled num_tokens_dp_padded
and num_tokens_dp_padded )
<= self.compilation_config.max_cudagraph_capture_size num_input_tokens = batch_desc.num_tokens
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_tokens_dp_padded
)
else:
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
...@@ -1246,9 +1226,7 @@ class SpecDecodeBaseProposer: ...@@ -1246,9 +1226,7 @@ class SpecDecodeBaseProposer:
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE cudagraph_runtime_mode=cudagraph_runtime_mode,
if cudagraphs_enabled
else CUDAGraphMode.NONE,
): ):
if self.supports_mm_inputs: if self.supports_mm_inputs:
input_ids = None input_ids = None
...@@ -1340,7 +1318,8 @@ class SpecDecodeBaseProposer: ...@@ -1340,7 +1318,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens_unpadded, num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config, parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False, allow_microbatching=False,
allow_dp_padding=self.use_cuda_graph, allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
!= CUDAGraphMode.NONE,
num_tokens_padded=num_tokens_padded, num_tokens_padded=num_tokens_padded,
uniform_decode=None, uniform_decode=None,
num_scheduled_tokens_per_request=None, num_scheduled_tokens_per_request=None,
......
...@@ -2139,15 +2139,11 @@ class GPUModelRunner( ...@@ -2139,15 +2139,11 @@ class GPUModelRunner(
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item() logits_indices[-1].item()
) )
if ( # Dispatch for the decoder portion of the model.
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE _, batch_desc = self.cudagraph_dispatcher.dispatch(
and num_logits <= self.cudagraph_batch_sizes[-1] num_logits, disable_full=True
): )
# Use piecewise CUDA graphs. num_logits_padded = batch_desc.num_tokens
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
:num_logits_padded :num_logits_padded
] ]
...@@ -5212,6 +5208,11 @@ class GPUModelRunner( ...@@ -5212,6 +5208,11 @@ class GPUModelRunner(
cudagraph_mode, self.uniform_decode_query_len cudagraph_mode, self.uniform_decode_query_len
) )
# Initialize eagle's cudagraph dispatcher if using eagle spec decode.
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None: def calculate_reorder_batch_threshold(self) -> None:
""" """
Choose the minimum reorder batch threshold from all attention groups. Choose the minimum reorder batch threshold from all attention groups.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment