Unverified Commit 3a414595 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[cudagraphs] Refactor cudagraph capture loop (#32946)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 8518b304
...@@ -173,6 +173,68 @@ class TestCudagraphDispatcher: ...@@ -173,6 +173,68 @@ class TestCudagraphDispatcher:
else: else:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
@pytest.mark.parametrize(
"cudagraph_mode_str,compilation_mode,expected_modes",
[
# FULL mode: only FULL keys, no PIECEWISE
("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
# PIECEWISE mode: only PIECEWISE keys
("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
# FULL_DECODE_ONLY: only FULL keys for uniform decode
("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
# NONE mode: no keys
("NONE", CompilationMode.NONE, []),
],
)
def test_get_capture_descs(
self, cudagraph_mode_str, compilation_mode, expected_modes
):
"""Test get_capture_descs returns correctly grouped and ordered descs."""
comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str,
mode=compilation_mode,
cudagraph_capture_sizes=[1, 4, 8, 16],
)
config = _create_vllm_config(comp_config, max_num_seqs=16)
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
)
capture_descs = dispatcher.get_capture_descs()
# Verify we get the expected modes
actual_modes = [mode for mode, _ in capture_descs]
assert actual_modes == expected_modes
# Verify each group is sorted largest-first
for mode, descs in capture_descs:
assert len(descs) > 0, "Each group should have at least one descriptor"
num_tokens_list = [d.num_tokens for d in descs]
assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
f"Descriptors for {mode} should be sorted largest-first"
)
# All descriptors in a group should have same uniform value
uniform_values = [d.uniform for d in descs]
assert len(set(uniform_values)) == 1, (
"All descriptors in a group should have the same uniform value"
)
def test_get_capture_descs_empty_when_not_initialized(self):
"""Test that get_capture_descs returns empty list when keys not initialized."""
comp_config = CompilationConfig(
cudagraph_mode="FULL",
mode=CompilationMode.NONE,
cudagraph_capture_sizes=[1, 8],
)
config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
# Don't initialize keys
assert dispatcher.get_capture_descs() == []
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper: class TestCUDAGraphWrapper:
......
...@@ -231,3 +231,26 @@ class CudagraphDispatcher: ...@@ -231,3 +231,26 @@ class CudagraphDispatcher:
# finally, just return no cudagraphs and a trivial batch descriptor # finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
"""
Returns capture descriptors for cudagraph capturing.
Returns:
List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
first then FULL. Batch descriptors are sorted largest-first for
memory efficiency.
"""
if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
return []
result = []
# Return in order: PIECEWISE first, then FULL
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
descs = list(self.cudagraph_keys[mode])
if descs:
# Sort by num_tokens descending (largest first)
descs.sort(key=lambda d: d.num_tokens, reverse=True)
result.append((mode, descs))
return result
...@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence ...@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from functools import reduce from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np import numpy as np
...@@ -4839,50 +4838,14 @@ class GPUModelRunner( ...@@ -4839,50 +4838,14 @@ class GPUModelRunner(
set_cudagraph_capturing_enabled(True) set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device): with freeze_gc(), graph_capture(device=self.device):
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config: for (
if self.compilation_config.cudagraph_specialize_lora: runtime_mode,
lora_cases = [True, False] batch_descs,
else: ) in self.cudagraph_dispatcher.get_capture_descs():
lora_cases = [True]
else:
lora_cases = [False]
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False,
)
# Capture full cudagraph for uniform decode batches if we
# don't already have full mixed prefill-decode cudagraphs.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
):
max_num_tokens = (
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
)
decode_cudagraph_batch_sizes = [
x
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs( self._capture_cudagraphs(
compilation_cases=compilation_cases_decode, batch_descriptors=batch_descs,
cudagraph_runtime_mode=CUDAGraphMode.FULL, cudagraph_runtime_mode=runtime_mode,
uniform_decode=True,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -4913,19 +4876,32 @@ class GPUModelRunner( ...@@ -4913,19 +4876,32 @@ class GPUModelRunner(
def _capture_cudagraphs( def _capture_cudagraphs(
self, self,
compilation_cases: list[tuple[int, bool]], batch_descriptors: list[BatchDescriptor],
cudagraph_runtime_mode: CUDAGraphMode, cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
): ):
assert ( assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes() and cudagraph_runtime_mode.valid_runtime_modes()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
if not batch_descriptors:
return
uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)
# Only rank 0 should print progress bar during capture # Only rank 0 should print progress bar during capture
if is_global_first_rank(): if is_global_first_rank():
compilation_cases = tqdm( batch_descriptors = tqdm(
compilation_cases, batch_descriptors,
disable=not self.load_config.use_tqdm_on_load, disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format( desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode", "decode" if uniform_decode else "mixed prefill-decode",
...@@ -4934,7 +4910,10 @@ class GPUModelRunner( ...@@ -4934,7 +4910,10 @@ class GPUModelRunner(
) )
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases: for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
activate_lora = batch_desc.has_lora
# We currently only capture ubatched graphs when its a FULL # We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens # cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched # is above the threshold. Otherwise we just capture a non-ubatched
...@@ -4952,28 +4931,22 @@ class GPUModelRunner( ...@@ -4952,28 +4931,22 @@ class GPUModelRunner(
for _ in range(self.compilation_config.cudagraph_num_of_warmups): for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup. # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to # But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is # if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture # different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention. # attention while `PIECEWISE` implies no attention.
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL dummy_run(
self._dummy_run(
num_tokens, num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora, activate_lora=activate_lora,
) )
self._dummy_run(
# Capture run
dummy_run(
num_tokens, num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora, activate_lora=activate_lora,
is_graph_capturing=True, is_graph_capturing=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment