Unverified Commit 80118853 authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[MM][Perf][CG] Support ViT full CUDA graph for Qwen3-VL video inference (#38061)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
Signed-off-by: default avatarShanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent c0ecaed9
......@@ -28,6 +28,7 @@ Multiple CUDA Graphs are pre-captured at different **token budget** levels (e.g.
class BudgetGraphMetadata:
token_budget: int
max_batch_size: int
max_frames_per_batch: int
graph: torch.cuda.CUDAGraph
input_buffer: torch.Tensor # e.g. pixel_values
metadata_buffers: dict[str, torch.Tensor] # e.g. embeddings, seq metadata
......@@ -51,6 +52,15 @@ For each graph replay:
When `mm_encoder_tp_mode="data"`, the manager distributes images across TP ranks using load-balanced assignment via `get_load_balance_assignment`, executes locally on each rank, then gathers results back in the original order via `tensor_model_parallel_all_gather`.
### Video inference support (experimental)
Following <https://github.com/vllm-project/vllm/pull/35963> (ViT full CUDA graph support for image inference), <https://github.com/vllm-project/vllm/pull/38061> extends the encoder CUDA graph framework to support video inference for Qwen3-VL. Previously, the CUDA graph capture/replay path only handled image inputs (`pixel_values` + `image_grid_thw`). Video inputs use different keys (`pixel_values_videos` + `video_grid_thw`) and require larger `cu_seqlens` buffers because each video item contributes multiple frames (`T` attention sequences). This PR generalizes the protocol and manager to handle both modalities through a single shared graph manager.
!!! note
Video CUDA graphs are automatically disabled when EVS (Efficient Video Sampling) pruning is enabled, since EVS makes the token count data-dependent and incompatible with CUDA graph capture.
Currently, we only support image-only or video-only inputs when enabling CUDA graph, mixed inputs (image + video) are not supported yet (we will work on it in the near future). Thus, it's recommended to turn off the image modality by `--limit-mm-per-prompt '{"image": 0}'` for video-only inputs.
## Model integration via `SupportsEncoderCudaGraph`
Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph] protocol. This protocol encapsulates all model-specific logic so that the manager remains model-agnostic. The protocol defines the following methods:
......@@ -65,12 +75,17 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
* `prepare_encoder_cudagraph_replay_buffers(...)` — computes new buffer values from actual batch inputs before replay.
* `encoder_cudagraph_forward(...)` — forward pass using precomputed buffers (called during capture and replay).
* `encoder_eager_forward(...)` — fallback eager forward when no graph fits.
Currently supported: **Qwen3-VL** (see `vllm/model_executor/models/qwen3_vl.py`).
* `get_input_modality(...)` - return the modality of the inputs.
!!! note
The `SupportsEncoderCudaGraph` protocol is designed to be model-agnostic. New vision encoder models can opt-in by implementing the protocol methods without modifying the manager.
**Supported models:**
| Architecture | Models | CG for Image | CG for Video |
| ------------ | ------ | ------------ | ------------ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
!!! note
Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
......@@ -80,10 +95,13 @@ Three fields in `CompilationConfig` control encoder CUDA Graphs:
* `cudagraph_mm_encoder` (`bool`, default `False`) — enable CUDA Graph capture for multimodal encoder. When enabled, captures the full encoder forward as a CUDA Graph for each token budget level.
* `encoder_cudagraph_token_budgets` (`list[int]`, default `[]`) — token budget levels for capture. If empty (default), auto-inferred from model architecture as power-of-2 levels. User-provided values override auto-inference.
* `encoder_cudagraph_max_images_per_batch` (`int`, default `0`) — maximum number of images per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
* `encoder_cudagraph_max_vision_items_per_batch` (`int`, default `0`) — maximum number of images/videos per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `0`) — maximum number of video frames per batch during capture. If 0 (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * 2` (to be optimized).
## Usage guide
### Image inference
Enable encoder CUDA Graphs via `compilation_config`:
```bash
......@@ -95,7 +113,7 @@ With explicit budgets:
```bash
vllm serve Qwen/Qwen3-VL-32B \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```
Python example:
......@@ -107,7 +125,7 @@ compilation_config = {
"cudagraph_mm_encoder": True,
# Optional: override auto-inferred budgets
# "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
# "encoder_cudagraph_max_images_per_batch": 8,
# "encoder_cudagraph_max_vision_items_per_batch": 8,
}
model = vllm.LLM(
......@@ -118,6 +136,44 @@ model = vllm.LLM(
The manager tracks hit/miss statistics and logs them periodically. A "hit" means an image was processed via CUDA Graph replay; a "miss" means eager fallback (image exceeded all budgets).
### Video inference
Enable encoder CUDA Graphs via `compilation_config`:
```bash
vllm serve Qwen/Qwen3-VL-32B \
--limit-mm-per-prompt '{"image": 0}' \
--compilation-config '{"cudagraph_mm_encoder": true}'
```
With explicit budgets:
```bash
vllm serve Qwen/Qwen3-VL-32B \
--limit-mm-per-prompt '{"image": 0}' \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824], "encoder_cudagraph_max_vision_items_per_batch": 8, "encoder_cudagraph_max_frames_per_batch": 64}'
```
Python example:
```python
import vllm
compilation_config = {
"cudagraph_mm_encoder": True,
# Optional: override auto-inferred budgets
# "encoder_cudagraph_token_budgets": [2048, 4096, 8192, 13824],
# "encoder_cudagraph_max_vision_items_per_batch": 8,
# "encoder_cudagraph_max_frames_per_batch": 64,
}
model = vllm.LLM(
model="Qwen/Qwen3-VL-32B",
limit_mm_per_prompt='{"image": 0}',
compilation_config=compilation_config,
)
```
## About the Performance
The following benchmarks were run on Blackwell GPUs (GB200) using `vllm bench mm-processor`. See [#35963](https://github.com/vllm-project/vllm/pull/35963) for full details.
......@@ -140,7 +196,7 @@ vllm bench mm-processor \
--num-prompts 3000 --num-warmups 300 \
--max-model-len 32768 --seed 42 \
--mm-encoder-attn-backend FLASH_ATTN \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```
### Multi-GPU (4x GB200, TP=4, DP=4)
......@@ -165,5 +221,8 @@ vllm bench mm-processor \
--max-model-len 8192 --seed 42 \
--mm-encoder-attn-backend FLASHINFER \
--tensor-parallel-size 4 --mm-encoder-tp-mode data \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_vision_items_per_batch": 8}'
```
!!! note
Find more details about benchmarks on GPUs (A100) for video inference at [#38061](https://github.com/vllm-project/vllm/pull/38061).
......@@ -6,8 +6,10 @@ Test organization:
No GPU required:
- TestFindBudgetGraph — greedy budget selection logic
- TestGetCumulativeStats — hit/miss rate statistics
- TestGetInputModality — modality routing from mm_kwargs keys
GPU required:
- TestEncoderCudaGraphCaptureReplay — capture, replay, fallback, counters, chunking
- TestEncoderCudaGraphVideoReplay — video modality capture, replay
"""
from typing import Any
......@@ -205,11 +207,19 @@ class SimpleMockViTModel(torch.nn.Module):
def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
return EncoderCudaGraphConfig(
modalities=["image"],
input_key="pixel_values",
input_key_by_modality={
"image": "pixel_values",
},
buffer_keys=["dummy_buf"],
out_hidden_size=_HIDDEN,
)
def get_input_modality(
self,
mm_kwargs: dict[str, Any],
) -> str:
return "image"
def get_encoder_cudagraph_budget_range(
self,
vllm_config,
......@@ -268,6 +278,7 @@ class SimpleMockViTModel(torch.nn.Module):
self,
token_budget: int,
max_batch_size: int,
max_frames_per_batch: int,
device: torch.device,
dtype: torch.dtype,
) -> EncoderCudaGraphCaptureInputs:
......@@ -294,6 +305,7 @@ class SimpleMockViTModel(torch.nn.Module):
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
max_frames_per_batch: int,
) -> EncoderCudaGraphReplayBuffers:
grid_thw = mm_kwargs["image_grid_thw"]
n_out = _count_output_tokens(grid_thw, _SPATIAL_MERGE)
......@@ -327,11 +339,16 @@ def _make_manager_for_gpu(
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
*,
max_frames_per_batch: int | None = None,
) -> EncoderCudaGraphManager:
"""Create EncoderCudaGraphManager bypassing VllmConfig for GPU tests."""
mgr = object.__new__(EncoderCudaGraphManager)
mgr.token_budgets = sorted(token_budgets)
mgr.max_batch_size = max_batch_size
mgr.max_frames_per_batch = (
max_frames_per_batch if max_frames_per_batch is not None else max_batch_size * 2
)
mgr.use_dp = False
mgr.budget_graphs = {}
mgr.graph_hits = 0
......@@ -366,6 +383,18 @@ def _make_mm_kwargs(
}
def _make_video_mm_kwargs(
grid_thw_list: list[list[int]],
device: torch.device,
dtype: torch.dtype,
) -> dict[str, Any]:
"""Create video mm_kwargs (pixel_values_videos / video_grid_thw) for testing."""
return {
"pixel_values_videos": _make_pixel_values(grid_thw_list, device, dtype),
"video_grid_thw": grid_thw_list,
}
# ---------------------------------------------------------------------------
# GPU tests — capture, replay, fallback, counters, chunking
# ---------------------------------------------------------------------------
......@@ -449,3 +478,285 @@ class TestEncoderCudaGraphCaptureReplay:
assert len(result) == n_images
for out in result:
assert out.shape == (4, _HIDDEN)
# ---------------------------------------------------------------------------
# SimpleMockViTVideoModel — extends SimpleMockViTModel with video support
# ---------------------------------------------------------------------------
class SimpleMockViTVideoModel(SimpleMockViTModel):
"""ViT mock that supports both image and video modalities.
Reuses SimpleMockViTModel's NN weights and _forward() logic.
Only the protocol methods that are key-dependent are overridden.
"""
def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
return EncoderCudaGraphConfig(
modalities=["image", "video"],
input_key_by_modality={
"image": "pixel_values",
"video": "pixel_values_videos",
},
buffer_keys=["dummy_buf"],
out_hidden_size=_HIDDEN,
)
def get_input_modality(self, mm_kwargs: dict[str, Any]) -> str:
return "video" if "video_grid_thw" in mm_kwargs else "image"
# ------------------------------------------------------------------
# Private helpers — route to the correct mm_kwargs keys
# ------------------------------------------------------------------
def _get_grid_thw(self, mm_kwargs: dict[str, Any]) -> list[list[int]]:
key = (
"video_grid_thw"
if self.get_input_modality(mm_kwargs) == "video"
else "image_grid_thw"
)
return mm_kwargs[key]
def _get_pixel_values(self, mm_kwargs: dict[str, Any]) -> torch.Tensor:
key = (
"pixel_values_videos"
if self.get_input_modality(mm_kwargs) == "video"
else "pixel_values"
)
return mm_kwargs[key]
# ------------------------------------------------------------------
# Protocol overrides that depend on modality keys
# ------------------------------------------------------------------
def get_encoder_cudagraph_num_items(self, mm_kwargs: dict[str, Any]) -> int:
return len(self._get_grid_thw(mm_kwargs))
def get_encoder_cudagraph_per_item_output_tokens(
self, mm_kwargs: dict[str, Any]
) -> list[int]:
m = _SPATIAL_MERGE
return [t * (h // m) * (w // m) for t, h, w in self._get_grid_thw(mm_kwargs)]
def get_encoder_cudagraph_per_item_input_sizes(
self, mm_kwargs: dict[str, Any]
) -> list[int]:
return [t * h * w for t, h, w in self._get_grid_thw(mm_kwargs)]
def select_encoder_cudagraph_items(
self, mm_kwargs: dict[str, Any], indices: list[int]
) -> dict[str, Any]:
modality = self.get_input_modality(mm_kwargs)
pv_key = "pixel_values_videos" if modality == "video" else "pixel_values"
grid_key = "video_grid_thw" if modality == "video" else "image_grid_thw"
grid_thw = self._get_grid_thw(mm_kwargs)
pixel_values = self._get_pixel_values(mm_kwargs)
if len(indices) == 0:
return {pv_key: pixel_values[:0], grid_key: []}
patches_per_item = [t * h * w for t, h, w in grid_thw]
cum_patches = [0]
for p in patches_per_item:
cum_patches.append(cum_patches[-1] + p)
selected_pv = torch.cat(
[pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
)
return {pv_key: selected_pv, grid_key: [grid_thw[i] for i in indices]}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
max_frames_per_batch: int,
device: torch.device,
dtype: torch.dtype,
) -> EncoderCudaGraphCaptureInputs:
per_item_output = token_budget // max_batch_size
frames_per_item = max_frames_per_batch // max_batch_size
if frames_per_item > 1:
# Video-format capture: size cu_seqlens for T frames per item.
tokens_per_frame = (
per_item_output + frames_per_item - 1
) // frames_per_item
grid_config = [
[frames_per_item, _SPATIAL_MERGE, tokens_per_frame * _SPATIAL_MERGE]
for _ in range(max_batch_size)
]
else:
grid_config = [
[1, _SPATIAL_MERGE, per_item_output * _SPATIAL_MERGE]
for _ in range(max_batch_size)
]
total_patches = _count_input_patches(grid_config)
# Use pixel_values (image key) for capture — same patch shape as video.
dummy_pixel_values = torch.randn(
total_patches, _FLAT, device=device, dtype=dtype
)
n_out = _count_output_tokens(grid_config, _SPATIAL_MERGE)
dummy_buf = torch.zeros(n_out, _HIDDEN, device=device, dtype=dtype)
return EncoderCudaGraphCaptureInputs(
mm_kwargs={
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
},
buffers={"dummy_buf": dummy_buf},
)
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
max_frames_per_batch: int,
) -> EncoderCudaGraphReplayBuffers:
n_out = _count_output_tokens(self._get_grid_thw(mm_kwargs), _SPATIAL_MERGE)
p = next(self.parameters())
dummy_buf = torch.zeros(n_out, _HIDDEN, device=p.device, dtype=p.dtype)
return EncoderCudaGraphReplayBuffers(buffers={"dummy_buf": dummy_buf})
def encoder_cudagraph_forward(
self, mm_kwargs: dict[str, Any], buffers: dict[str, torch.Tensor]
) -> torch.Tensor:
return self._forward(self._get_pixel_values(mm_kwargs))
def encoder_eager_forward(self, mm_kwargs: dict[str, Any]) -> torch.Tensor:
return self._forward(self._get_pixel_values(mm_kwargs))
# ---------------------------------------------------------------------------
# No-GPU tests — get_input_modality routing
# ---------------------------------------------------------------------------
class TestGetInputModality:
"""get_input_modality returns correct modality based on mm_kwargs keys."""
def test_image_only_model_always_returns_image(self):
model = SimpleMockViTModel()
mm_kwargs = {
"pixel_values": torch.zeros(1, _FLAT),
"image_grid_thw": [[1, 4, 4]],
}
assert model.get_input_modality(mm_kwargs) == "image"
def test_video_model_returns_image_for_image_kwargs(self):
model = SimpleMockViTVideoModel()
mm_kwargs = {
"pixel_values": torch.zeros(1, _FLAT),
"image_grid_thw": [[1, 4, 4]],
}
assert model.get_input_modality(mm_kwargs) == "image"
def test_video_model_returns_video_for_video_kwargs(self):
model = SimpleMockViTVideoModel()
mm_kwargs = {
"pixel_values_videos": torch.zeros(8, _FLAT),
"video_grid_thw": [[2, 4, 4]],
}
assert model.get_input_modality(mm_kwargs) == "video"
def test_video_model_config_has_both_modalities(self):
model = SimpleMockViTVideoModel()
cfg = model.get_encoder_cudagraph_config()
assert "image" in cfg.modalities
assert "video" in cfg.modalities
assert cfg.input_key_by_modality["image"] == "pixel_values"
assert cfg.input_key_by_modality["video"] == "pixel_values_videos"
# ---------------------------------------------------------------------------
# GPU tests — video capture, replay, fallback, and mixed image+video
# ---------------------------------------------------------------------------
_VIDEO_MAX_BATCH = 4
_VIDEO_MAX_FRAMES = 8 # 2 frames per item at max_batch_size=4
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestEncoderCudaGraphVideoReplay:
def setup_method(self):
self.device = torch.device("cuda:0")
self.dtype = torch.float16
self.model = SimpleMockViTVideoModel().to(self.device).half()
self.mgr = _make_manager_for_gpu(
self.model,
_BUDGETS,
_VIDEO_MAX_BATCH,
self.device,
self.dtype,
max_frames_per_batch=_VIDEO_MAX_FRAMES,
)
self.mgr.capture()
# --- capture ---
def test_capture_creates_one_graph_per_budget(self):
assert len(self.mgr.budget_graphs) == len(_BUDGETS)
assert set(self.mgr.budget_graphs.keys()) == set(_BUDGETS)
# --- output shape ---
def test_video_execute_returns_one_tensor_per_video(self):
# T=2, 4x4 → 2*(4//2)*(4//2) = 8 tokens per video
grid_thw = [[2, 4, 4], [2, 4, 4]]
mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert len(result) == 2
def test_video_output_tokens_per_item(self):
# T=2,4x4 → 8 tokens; T=1,4x4 → 4 tokens
grid_thw = [[2, 4, 4], [1, 4, 4]]
mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert result[0].shape == (8, _HIDDEN)
assert result[1].shape == (4, _HIDDEN)
# --- budget fallback ---
def test_video_eager_fallback_when_tokens_exceed_all_budgets(self):
# T=2, 18x18 → 2*(18//2)*(18//2) = 162 tokens > max budget 64
grid_thw = [[2, 18, 18]]
mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert len(result) == 1
assert result[0].shape == (162, _HIDDEN)
assert self.mgr.graph_misses == 1
# --- counters ---
def test_video_hit_counter_increments_by_num_videos(self):
grid_thw = [[2, 4, 4], [1, 4, 4]]
mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype)
self.mgr.execute(mm_kwargs)
assert self.mgr.graph_hits == 2
def test_video_miss_counter_increments_for_oversized_video(self):
grid_thw = [[2, 18, 18]] # 162 tokens > 64
mm_kwargs = _make_video_mm_kwargs(grid_thw, self.device, self.dtype)
self.mgr.execute(mm_kwargs)
assert self.mgr.graph_misses == 1
# --- image and video sharing the same manager ---
def test_image_and_video_share_manager(self):
"""Image and video inputs can both be executed through the same manager."""
img_grid = [[1, 4, 4], [1, 4, 4]]
img_result = self.mgr.execute(
_make_mm_kwargs(img_grid, self.device, self.dtype)
)
vid_grid = [[2, 4, 4]]
vid_result = self.mgr.execute(
_make_video_mm_kwargs(vid_grid, self.device, self.dtype)
)
assert len(img_result) == 2
assert len(vid_result) == 1
assert img_result[0].shape == (4, _HIDDEN)
assert vid_result[0].shape == (8, _HIDDEN)
......@@ -519,13 +519,21 @@ class CompilationConfig:
User-provided values override auto-inference.
Example: [2048, 4096, 8192, 13824]"""
encoder_cudagraph_max_images_per_batch: int = 0
"""Maximum number of images per batch for encoder CUDA graph capture.
encoder_cudagraph_max_vision_items_per_batch: int = 0
"""Maximum number of images/videos per batch for encoder CUDA graph capture.
Determines the fixed batch size used during graph capture.
If 0 (default), auto-inferred as max_budget // min_budget from the
model's budget range. User-provided positive value overrides
auto-inference."""
encoder_cudagraph_max_frames_per_batch: int = 0
"""Maximum total video frames per batch for encoder CUDA graph capture.
Controls the cu_seqlens buffer size (one entry per attention sequence,
i.e. one per video frame). If 0 (default), auto-inferred per budget
level as token_budget (tight bound: packing guarantees
sum(T_i) <= token_budget). Positive value overrides auto-inference
and applies to all budget levels."""
# Inductor capture
compile_sizes: list[int | str] | None = None
"""Sizes to compile for inductor. In addition
......@@ -964,10 +972,18 @@ class CompilationConfig:
# Validate encoder CUDA graph configuration
if (
self.cudagraph_mm_encoder
and self.encoder_cudagraph_max_images_per_batch < 0
and self.encoder_cudagraph_max_vision_items_per_batch < 0
):
raise ValueError(
"encoder_cudagraph_max_vision_items_per_batch must be "
"non-negative (0 = auto-infer)"
)
if (
self.cudagraph_mm_encoder
and self.encoder_cudagraph_max_frames_per_batch < 0
):
raise ValueError(
"encoder_cudagraph_max_images_per_batch must be "
"encoder_cudagraph_max_frames_per_batch must be "
"non-negative (0 = auto-infer)"
)
......
......@@ -1524,6 +1524,13 @@ class SupportsEncoderCudaGraph(Protocol):
def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig": ...
def get_input_modality(
self,
mm_kwargs: dict[str, Any],
) -> str:
"""Return the modality of the inputs."""
...
def get_encoder_cudagraph_budget_range(
self,
vllm_config: "VllmConfig",
......@@ -1536,7 +1543,7 @@ class SupportsEncoderCudaGraph(Protocol):
(e.g. max_num_batched_tokens)
Used when ``encoder_cudagraph_token_budgets`` and/or
``encoder_cudagraph_max_images_per_batch`` are not explicitly
``encoder_cudagraph_max_vision_items_per_batch`` are not explicitly
specified by the user.
"""
...
......@@ -1590,6 +1597,7 @@ class SupportsEncoderCudaGraph(Protocol):
self,
token_budget: int,
max_batch_size: int,
max_frames_per_batch: int,
device: torch.device,
dtype: torch.dtype,
) -> "EncoderCudaGraphCaptureInputs":
......@@ -1600,6 +1608,7 @@ class SupportsEncoderCudaGraph(Protocol):
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
max_frames_per_batch: int,
) -> "EncoderCudaGraphReplayBuffers":
"""Compute buffer values from actual batch inputs for replay."""
...
......
......@@ -99,6 +99,7 @@ from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers
from .interfaces import (
MultiModalEmbeddings,
......@@ -689,6 +690,7 @@ class Qwen3_VisionTransformer(nn.Module):
grid_thw_list: list[list[int]],
*,
max_batch_size: int | None = None,
max_frames_per_batch: int | None = None,
max_seqlen_override: int | None = None,
device: torch.device | None = None,
) -> dict[str, torch.Tensor | None]:
......@@ -701,6 +703,10 @@ class Qwen3_VisionTransformer(nn.Module):
grid_thw_list: Grid configurations as list of [t, h, w].
max_batch_size: If set, pad cu_seqlens to this size
(needed for CUDA graph capture/replay).
max_frames_per_batch: If set, overrides max_batch_size for
cu_seqlens padding. For video inputs each item contributes
T attention sequences (frames); this sizes the buffer to
the total frame budget so video replays never overflow.
max_seqlen_override: If set, use this value for max_seqlen
instead of computing from cu_seqlens (needed for CUDA
graph capture to cover worst-case replay scenarios).
......@@ -725,15 +731,21 @@ class Qwen3_VisionTransformer(nn.Module):
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
# Pad cu_seqlens if max_batch_size specified
if max_batch_size is not None:
# Pad cu_seqlens to the required number of sequences.
# For videos each item contributes T frames = T attention sequences,
# so the total can exceed max_batch_size. max_frames_per_batch
# overrides the pad target when set.
pad_to = (
max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
)
if pad_to is not None:
num_seqs = len(cu_seqlens) - 1
if num_seqs < max_batch_size:
if num_seqs < pad_to:
cu_seqlens = np.concatenate(
[
cu_seqlens,
np.full(
max_batch_size - num_seqs,
pad_to - num_seqs,
cu_seqlens[-1],
dtype=np.int32,
),
......@@ -1737,9 +1749,21 @@ class Qwen3VLForConditionalGeneration(
EncoderCudaGraphConfig,
)
modalities = ["image"]
# NOTE: When EVS (Efficient Video Sampling) pruning is enabled, the number
# of tokens becomes data-dependent (i.e., the retained tokens are
# dynamically selected based on inter-frame differences) and therefore
# cannot be captured by CUDA Graphs. As a result, video CUDA Graphs are
# only enabled when EVS is disabled.
if not self.is_multimodal_pruning_enabled:
modalities.append("video")
return EncoderCudaGraphConfig(
modalities=["image"],
input_key="pixel_values",
modalities=modalities,
input_key_by_modality={
"image": "pixel_values",
"video": "pixel_values_videos",
},
buffer_keys=[
"pos_embeds",
"rotary_pos_emb_cos",
......@@ -1751,49 +1775,86 @@ class Qwen3VLForConditionalGeneration(
out_hidden_size=self.visual.out_hidden_size,
)
def get_input_modality(
self,
mm_kwargs: dict[str, Any],
) -> str:
if "image_grid_thw" in mm_kwargs:
return "image"
return "video"
def get_encoder_cudagraph_budget_range(
self,
vllm_config,
) -> tuple[int, int]:
# Min: estimated smallest possible encoder input.
# 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
# 224x224 image → 16x16 patches (patch_size=14)
# spatial_merge_size=2 → 8x8 = 64 tokens
min_budget = 64
# Max: capped by max_num_batched_tokens
max_budget = vllm_config.scheduler_config.max_num_batched_tokens
return (min_budget, max_budget)
def _get_pixel_values_by_modality(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
if self.get_input_modality(mm_kwargs) == "image":
pixel_values = mm_kwargs["pixel_values"]
else:
pixel_values = mm_kwargs["pixel_values_videos"]
return pixel_values
def _get_grid_thw_by_modality(
self,
mm_kwargs: dict[str, Any],
) -> list[tuple[int, int, int]]:
grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw"
grid_thw = mm_kwargs[grid_thw_key]
if not isinstance(grid_thw, list):
grid_thw = grid_thw.tolist()
return grid_thw
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
return len(mm_kwargs["image_grid_thw"])
return len(self._get_grid_thw_by_modality(mm_kwargs))
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
m = self.visual.spatial_merge_size
return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return [t * (h // m) * (w // m) for t, h, w in grid_thw]
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return [t * h * w for t, h, w in grid_thw]
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = mm_kwargs["pixel_values"]
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
if len(indices) == 0:
if self.get_input_modality(mm_kwargs) == "image":
return {
"pixel_values": pixel_values[:0],
"image_grid_thw": [],
}
else:
return {
"pixel_values_videos": pixel_values[:0],
"video_grid_thw": [],
}
# Compute cumulative patch offsets for slicing pixel_values
patches_per_item = [t * h * w for t, h, w in grid_thw]
......@@ -1806,15 +1867,22 @@ class Qwen3VLForConditionalGeneration(
)
selected_grid = [grid_thw[i] for i in indices]
if self.get_input_modality(mm_kwargs) == "image":
return {
"pixel_values": selected_pv,
"image_grid_thw": selected_grid,
}
else:
return {
"pixel_values_videos": selected_pv,
"video_grid_thw": selected_grid,
}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
max_frames_per_batch: int,
device: torch.device,
dtype: torch.dtype,
):
......@@ -1823,12 +1891,33 @@ class Qwen3VLForConditionalGeneration(
)
spatial_merge_size = self.visual.spatial_merge_size
per_image_output = token_budget // max_batch_size
# Synthetic rectangular grid: [1, merge, per_image_output * merge]
# produces exactly per_image_output tokens per image.
per_mm_item_output = token_budget // max_batch_size
frames_per_item = max_frames_per_batch // max_batch_size
if frames_per_item > 1:
# Build the capture grid using a video-format layout so that
# cu_seqlens is sized for video replays from the start.
# cu_seqlens has one entry per attention sequence (one per frame),
# so using T > 1 per item makes the buffer large enough without
# relying solely on padding.
# Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output
# so the pixel_values buffer covers any valid single-item replay.
tokens_per_frame = (
per_mm_item_output + frames_per_item - 1
) // frames_per_item
# Video-format grid_config (T=frames_per_item).
grid_config = [
[
frames_per_item,
spatial_merge_size,
tokens_per_frame * spatial_merge_size,
]
for _ in range(max_batch_size)
]
else:
# Image-format grid_config (T=1).
grid_config = [
[1, spatial_merge_size, per_image_output * spatial_merge_size]
[1, spatial_merge_size, per_mm_item_output * spatial_merge_size]
for _ in range(max_batch_size)
]
......@@ -1848,15 +1937,18 @@ class Qwen3VLForConditionalGeneration(
# Override max_seqlen with a safe upper bound for capture.
# max_seqlen.item() gets baked into the CUDA graph (not replayed),
# so the capture value must cover any replay scenario.
# Worst case: 1 image consuming the full budget ->
# Worst case: 1 item consuming the full budget ->
# seq_len = token_budget * spatial_merge_size^2.
buffers = self.visual.prepare_encoder_metadata(
grid_config,
max_batch_size=max_batch_size,
max_frames_per_batch=max_frames_per_batch,
max_seqlen_override=token_budget * (spatial_merge_size**2),
device=device,
)
# Just use image-modality dummy input_buffer for capturing, since it's also
# compatible for video inputs (has the same shape: [num_patches, C*T*P*P]).
mm_kwargs = {
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
......@@ -1871,17 +1963,21 @@ class Qwen3VLForConditionalGeneration(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
max_frames_per_batch: int,
):
from vllm.v1.worker.encoder_cudagraph_defs import (
EncoderCudaGraphReplayBuffers,
)
grid_thw_list = mm_kwargs["image_grid_thw"]
modality = self.get_input_modality(mm_kwargs)
grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs)
if modality == "image":
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_batch_size=max_batch_size,
)
else:
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_frames_per_batch=max_frames_per_batch,
)
return EncoderCudaGraphReplayBuffers(buffers=buffers)
......@@ -1890,16 +1986,16 @@ class Qwen3VLForConditionalGeneration(
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return self.visual(pixel_values, grid_thw, encoder_metadata=buffers)
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return self.visual(pixel_values, grid_thw)
def _parse_and_validate_image_input(
......
......@@ -36,6 +36,7 @@ class BudgetGraphMetadata:
token_budget: int
max_batch_size: int # Max number of images/videos per batch
max_frames_per_batch: int # Max total frames per batch (for video)
graph: torch.cuda.CUDAGraph
# The input tensor updated before replay (e.g. pixel_values)
input_buffer: torch.Tensor
......@@ -66,12 +67,13 @@ class EncoderCudaGraphManager:
comp_config = vllm_config.compilation_config
user_budgets = comp_config.encoder_cudagraph_token_budgets
user_max_images = comp_config.encoder_cudagraph_max_images_per_batch
user_max_mm_items = comp_config.encoder_cudagraph_max_vision_items_per_batch
user_max_frames = comp_config.encoder_cudagraph_max_frames_per_batch
if user_budgets and user_max_images > 0:
if user_budgets and user_max_mm_items > 0:
# Fully user-specified
self.token_budgets = sorted(user_budgets)
self.max_batch_size = user_max_images
self.max_batch_size = user_max_mm_items
else:
# Auto-infer missing values from model
min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
......@@ -83,9 +85,15 @@ class EncoderCudaGraphManager:
else self._generate_budgets(min_budget, max_budget)
)
self.max_batch_size = (
user_max_images if user_max_images > 0 else max_budget // min_budget
user_max_mm_items if user_max_mm_items > 0 else max_budget // min_budget
)
if user_max_frames > 0:
self.max_frames_per_batch = user_max_frames
else:
# TODO(shen-shanshan): optimize this auto-infer for max_frames_per_batch.
self.max_frames_per_batch = self.max_batch_size * 2
mm_config = vllm_config.model_config.multimodal_config
self.use_dp = (
mm_config is not None
......@@ -100,9 +108,10 @@ class EncoderCudaGraphManager:
logger.info(
"EncoderCudaGraphManager initialized with "
"budgets=%s, max_batch_size=%d, use_dp=%s",
"budgets=%s, max_batch_size=%d, max_frames_per_batch=%s, use_dp=%s",
self.token_budgets,
self.max_batch_size,
self.max_frames_per_batch if self.max_frames_per_batch > 0 else "auto",
self.use_dp,
)
......@@ -136,13 +145,19 @@ class EncoderCudaGraphManager:
def _capture_budget_graph(self, token_budget: int):
"""Capture CUDA graph for a single token budget."""
logger.debug(
"Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
"Capturing encoder cudagraph for budget=%d, max_batch_size=%d, "
"max_frames_per_batch=%d",
token_budget,
self.max_batch_size,
self.max_frames_per_batch,
)
capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
token_budget, self.max_batch_size, self.device, self.dtype
token_budget,
self.max_batch_size,
self.max_frames_per_batch,
self.device,
self.dtype,
)
mm_kwargs = capture_inputs.mm_kwargs
......@@ -157,10 +172,14 @@ class EncoderCudaGraphManager:
output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
output_buffer.copy_(output)
input_key = self.config.input_key
# Since the image and video modalities share the same per-patch shape,
# so we can use the image dummy inputs to capture CUDA graph for both
# image and video.
input_key = self.config.input_key_by_modality["image"]
self.budget_graphs[token_budget] = BudgetGraphMetadata(
token_budget=token_budget,
max_batch_size=self.max_batch_size,
max_frames_per_batch=self.max_frames_per_batch,
graph=graph,
input_buffer=mm_kwargs[input_key],
metadata_buffers=buffers,
......@@ -230,10 +249,11 @@ class EncoderCudaGraphManager:
# Copy the input tensor. Buffers are sized for the full budget;
# actual inputs may be smaller. Zero then slice-copy so padded
# positions are invisible to attention (cu_seqlens masks them out).
input_key = self.config.input_key
input_key = self.config.input_key_by_modality[
self.model.get_input_modality(mm_kwargs)
]
src = mm_kwargs[input_key]
n = src.shape[0]
graph_meta.input_buffer.zero_()
graph_meta.input_buffer[:n].copy_(src)
# Copy metadata buffers using keys from config.buffer_keys.
......@@ -362,7 +382,9 @@ class EncoderCudaGraphManager:
(token_budget - batch_out_tokens) / token_budget * 100,
)
replay = self.model.prepare_encoder_cudagraph_replay_buffers(
batch_mm_kwargs, self.max_batch_size
batch_mm_kwargs,
self.max_batch_size,
self.max_frames_per_batch,
)
# graph_hits counted inside _run_budget_graph after replay.
......
......@@ -20,8 +20,10 @@ class EncoderCudaGraphConfig:
modalities: list[str]
"""Supported modalities (e.g. ["image"])."""
input_key: str
"""Key in mm_kwargs for the input tensor (e.g. "pixel_values")."""
input_key_by_modality: dict[str, str]
"""Per-modality input tensor key mapping, e.g.
{"image": "pixel_values", "video": "pixel_values_videos"}.
"""
buffer_keys: list[str]
"""Keys for the tensor buffers recorded into the CUDA graph.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment