Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f85e479e
Unverified
Commit
f85e479e
authored
Mar 23, 2026
by
Baorun (Lauren) Mu
Committed by
GitHub
Mar 23, 2026
Browse files
[Feature] ViT Full CUDA Graph (#35963)
Signed-off-by:
Baorun Mu
<
bmu@nvidia.com
>
parent
1f0d2106
Changes
7
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1584 additions
and
31 deletions
+1584
-31
tests/v1/cudagraph/test_encoder_cudagraph.py
tests/v1/cudagraph/test_encoder_cudagraph.py
+451
-0
vllm/config/compilation.py
vllm/config/compilation.py
+32
-0
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+141
-0
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+270
-30
vllm/v1/worker/gpu/mm/encoder_cudagraph.py
vllm/v1/worker/gpu/mm/encoder_cudagraph.py
+576
-0
vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py
vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py
+66
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+48
-1
No files found.
tests/v1/cudagraph/test_encoder_cudagraph.py
0 → 100644
View file @
f85e479e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for EncoderCudaGraphManager.
Test organization:
No GPU required:
- TestFindBudgetGraph — greedy budget selection logic
- TestGetCumulativeStats — hit/miss rate statistics
GPU required:
- TestEncoderCudaGraphCaptureReplay — capture, replay, fallback, counters, chunking
"""
from
typing
import
Any
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.gpu.mm.encoder_cudagraph
import
(
EncoderCudaGraphManager
,
)
from
vllm.v1.worker.gpu.mm.encoder_cudagraph_defs
import
(
EncoderCudaGraphCaptureInputs
,
EncoderCudaGraphConfig
,
EncoderCudaGraphReplayBuffers
,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def
_make_manager_with_budgets
(
budgets
:
list
[
int
])
->
EncoderCudaGraphManager
:
"""Create a minimal EncoderCudaGraphManager with only token_budgets set.
Skips the parts of __init__ that require a real VllmConfig / model
by patching the attributes directly after construction.
"""
mgr
=
object
.
__new__
(
EncoderCudaGraphManager
)
mgr
.
token_budgets
=
sorted
(
budgets
)
mgr
.
max_batch_size
=
16
mgr
.
use_dp
=
False
mgr
.
budget_graphs
=
{}
mgr
.
graph_hits
=
0
mgr
.
graph_misses
=
0
mgr
.
log_stats_interval
=
100
return
mgr
# ---------------------------------------------------------------------------
# _generate_budgets
# ---------------------------------------------------------------------------
class
TestGenerateBudgets
:
"""Auto-generate power-of-2 budgets from min to max."""
def
test_exact_powers_of_2
(
self
):
result
=
EncoderCudaGraphManager
.
_generate_budgets
(
64
,
1024
)
assert
result
==
[
64
,
128
,
256
,
512
,
1024
]
def
test_max_not_power_of_2
(
self
):
result
=
EncoderCudaGraphManager
.
_generate_budgets
(
64
,
800
)
assert
result
==
[
64
,
128
,
256
,
512
,
800
]
def
test_min_equals_max
(
self
):
result
=
EncoderCudaGraphManager
.
_generate_budgets
(
64
,
64
)
assert
result
==
[
64
]
def
test_large_range
(
self
):
result
=
EncoderCudaGraphManager
.
_generate_budgets
(
64
,
8192
)
assert
result
==
[
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
# ---------------------------------------------------------------------------
# _find_smallest_fitting_budget_given_tokens
# ---------------------------------------------------------------------------
class
TestFindBudgetGraph
:
"""Budget greedy selection: smallest budget >= total_tokens."""
@
pytest
.
mark
.
parametrize
(
"total_tokens,budgets,expected"
,
[
# Exact match
(
2048
,
[
2048
,
4096
,
8192
],
2048
),
# Below smallest budget — picks smallest
(
100
,
[
2048
,
4096
,
8192
],
2048
),
# Zero tokens — picks smallest
(
0
,
[
2048
,
4096
,
8192
],
2048
),
# Between budgets — picks next one up
(
2049
,
[
2048
,
4096
,
8192
],
4096
),
(
4097
,
[
2048
,
4096
,
8192
],
8192
),
# Exceeds all budgets — returns None (eager fallback)
(
9000
,
[
2048
,
4096
,
8192
],
None
),
# Single budget, fits
(
1000
,
[
2048
],
2048
),
# Single budget, does not fit
(
3000
,
[
2048
],
None
),
],
)
def
test_find_budget
(
self
,
total_tokens
,
budgets
,
expected
):
mgr
=
_make_manager_with_budgets
(
budgets
)
result
=
mgr
.
_find_smallest_fitting_budget_given_tokens
(
total_tokens
)
assert
result
==
expected
def
test_budgets_are_sorted
(
self
):
"""Manager always sorts budgets ascending at init."""
mgr
=
_make_manager_with_budgets
([
8192
,
2048
,
4096
])
assert
mgr
.
token_budgets
==
[
2048
,
4096
,
8192
]
# Budget selection still works correctly after sorting
assert
mgr
.
_find_smallest_fitting_budget_given_tokens
(
3000
)
==
4096
# ---------------------------------------------------------------------------
# get_cumulative_stats
# ---------------------------------------------------------------------------
class
TestGetCumulativeStats
:
"""Statistics tracking and reporting."""
def
test_initial_stats_are_zero
(
self
):
mgr
=
_make_manager_with_budgets
([
2048
])
stats
=
mgr
.
get_cumulative_stats
()
assert
stats
[
"graph_hits"
]
==
0
assert
stats
[
"graph_misses"
]
==
0
assert
stats
[
"hit_rate"
]
==
0.0
def
test_hit_rate_calculation
(
self
):
mgr
=
_make_manager_with_budgets
([
2048
])
mgr
.
graph_hits
=
75
mgr
.
graph_misses
=
25
stats
=
mgr
.
get_cumulative_stats
()
assert
stats
[
"graph_hits"
]
==
75
assert
stats
[
"graph_misses"
]
==
25
assert
stats
[
"hit_rate"
]
==
pytest
.
approx
(
0.75
)
def
test_all_hits
(
self
):
mgr
=
_make_manager_with_budgets
([
2048
])
mgr
.
graph_hits
=
100
mgr
.
graph_misses
=
0
assert
mgr
.
get_cumulative_stats
()[
"hit_rate"
]
==
pytest
.
approx
(
1.0
)
def
test_all_misses
(
self
):
mgr
=
_make_manager_with_budgets
([
2048
])
mgr
.
graph_hits
=
0
mgr
.
graph_misses
=
50
assert
mgr
.
get_cumulative_stats
()[
"hit_rate"
]
==
pytest
.
approx
(
0.0
)
def
test_stats_report_budget_info
(
self
):
budgets
=
[
2048
,
4096
,
8192
]
mgr
=
_make_manager_with_budgets
(
budgets
)
stats
=
mgr
.
get_cumulative_stats
()
assert
stats
[
"num_budgets"
]
==
0
# no graphs captured yet
assert
stats
[
"token_budgets"
]
==
budgets
# ---------------------------------------------------------------------------
# GPU fixtures and helpers
# ---------------------------------------------------------------------------
# Mock encoder parameters (kept small for fast capture)
_SPATIAL_MERGE
=
2
_HIDDEN
=
32
_PATCH_SIZE
=
4
# H/W per patch in grid_thw units
_TEMPORAL_PATCH
=
1
_IN_CHANNELS
=
3
# flattened_patch_size = in_channels * temporal_patch * patch_size^2
_FLAT
=
_IN_CHANNELS
*
_TEMPORAL_PATCH
*
_PATCH_SIZE
*
_PATCH_SIZE
# 48
# Test budgets: small to keep capture fast
_BUDGETS
=
[
16
,
64
]
_MAX_BATCH
=
4
def
_count_input_patches
(
grid_thw_list
:
list
[
list
[
int
]])
->
int
:
return
sum
(
t
*
h
*
w
for
t
,
h
,
w
in
grid_thw_list
)
def
_count_output_tokens
(
grid_thw_list
:
list
[
list
[
int
]],
spatial_merge_size
:
int
)
->
int
:
m
=
spatial_merge_size
return
sum
(
t
*
(
h
//
m
)
*
(
w
//
m
)
for
t
,
h
,
w
in
grid_thw_list
)
class
SimpleMockViTModel
(
torch
.
nn
.
Module
):
"""Minimal ViT model for CUDA graph tests.
Implements the SupportsEncoderCudaGraph protocol by providing
all required methods. The forward pass projects patches and
simulates spatial merge by averaging groups of m^2 patches.
"""
supports_encoder_cudagraph
=
True
def
__init__
(
self
):
super
().
__init__
()
self
.
proj
=
torch
.
nn
.
Linear
(
_FLAT
,
_HIDDEN
)
self
.
spatial_merge_size
=
_SPATIAL_MERGE
self
.
out_hidden_size
=
_HIDDEN
def
get_encoder_cudagraph_config
(
self
)
->
EncoderCudaGraphConfig
:
return
EncoderCudaGraphConfig
(
modalities
=
[
"image"
],
input_key
=
"pixel_values"
,
buffer_keys
=
[
"dummy_buf"
],
out_hidden_size
=
_HIDDEN
,
)
def
get_encoder_cudagraph_budget_range
(
self
,
vllm_config
,
)
->
tuple
[
int
,
int
]:
# For tests: min=4, max=128 (small values for fast capture)
return
(
4
,
128
)
def
get_encoder_cudagraph_num_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
int
:
return
len
(
mm_kwargs
[
"image_grid_thw"
])
def
get_encoder_cudagraph_per_item_output_tokens
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
m
=
_SPATIAL_MERGE
return
[
t
*
(
h
//
m
)
*
(
w
//
m
)
for
t
,
h
,
w
in
mm_kwargs
[
"image_grid_thw"
]]
def
get_encoder_cudagraph_per_item_input_sizes
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
return
[
t
*
h
*
w
for
t
,
h
,
w
in
mm_kwargs
[
"image_grid_thw"
]]
def
select_encoder_cudagraph_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
indices
:
list
[
int
],
)
->
dict
[
str
,
Any
]:
grid_thw
=
mm_kwargs
[
"image_grid_thw"
]
pixel_values
=
mm_kwargs
[
"pixel_values"
]
if
len
(
indices
)
==
0
:
return
{
"pixel_values"
:
pixel_values
[:
0
],
"image_grid_thw"
:
[],
}
patches_per_item
=
[
t
*
h
*
w
for
t
,
h
,
w
in
grid_thw
]
cum_patches
=
[
0
]
for
p
in
patches_per_item
:
cum_patches
.
append
(
cum_patches
[
-
1
]
+
p
)
selected_pv
=
torch
.
cat
(
[
pixel_values
[
cum_patches
[
i
]
:
cum_patches
[
i
+
1
]]
for
i
in
indices
]
)
selected_grid
=
[
grid_thw
[
i
]
for
i
in
indices
]
return
{
"pixel_values"
:
selected_pv
,
"image_grid_thw"
:
selected_grid
,
}
def
prepare_encoder_cudagraph_capture_inputs
(
self
,
token_budget
:
int
,
max_batch_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
EncoderCudaGraphCaptureInputs
:
per_image_output
=
token_budget
//
max_batch_size
grid_config
=
[
[
1
,
_SPATIAL_MERGE
,
per_image_output
*
_SPATIAL_MERGE
]
for
_
in
range
(
max_batch_size
)
]
total_patches
=
_count_input_patches
(
grid_config
)
dummy_pixel_values
=
torch
.
randn
(
total_patches
,
_FLAT
,
device
=
device
,
dtype
=
dtype
)
n_out
=
_count_output_tokens
(
grid_config
,
_SPATIAL_MERGE
)
dummy_buf
=
torch
.
zeros
(
n_out
,
_HIDDEN
,
device
=
device
,
dtype
=
dtype
)
return
EncoderCudaGraphCaptureInputs
(
mm_kwargs
=
{
"pixel_values"
:
dummy_pixel_values
,
"image_grid_thw"
:
grid_config
,
},
buffers
=
{
"dummy_buf"
:
dummy_buf
},
)
def
prepare_encoder_cudagraph_replay_buffers
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
max_batch_size
:
int
,
)
->
EncoderCudaGraphReplayBuffers
:
grid_thw
=
mm_kwargs
[
"image_grid_thw"
]
n_out
=
_count_output_tokens
(
grid_thw
,
_SPATIAL_MERGE
)
p
=
next
(
self
.
parameters
())
dummy_buf
=
torch
.
zeros
(
n_out
,
_HIDDEN
,
device
=
p
.
device
,
dtype
=
p
.
dtype
)
return
EncoderCudaGraphReplayBuffers
(
buffers
=
{
"dummy_buf"
:
dummy_buf
})
def
encoder_cudagraph_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
buffers
:
dict
[
str
,
torch
.
Tensor
],
)
->
torch
.
Tensor
:
return
self
.
_forward
(
mm_kwargs
[
"pixel_values"
])
def
encoder_eager_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
torch
.
Tensor
:
return
self
.
_forward
(
mm_kwargs
[
"pixel_values"
])
def
_forward
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
m2
=
_SPATIAL_MERGE
**
2
out
=
self
.
proj
(
pixel_values
)
n_out
=
out
.
shape
[
0
]
//
m2
return
out
[:
n_out
*
m2
].
view
(
n_out
,
m2
,
_HIDDEN
).
mean
(
dim
=
1
)
def
_make_manager_for_gpu
(
model
:
SimpleMockViTModel
,
token_budgets
:
list
[
int
],
max_batch_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
EncoderCudaGraphManager
:
"""Create EncoderCudaGraphManager bypassing VllmConfig for GPU tests."""
mgr
=
object
.
__new__
(
EncoderCudaGraphManager
)
mgr
.
token_budgets
=
sorted
(
token_budgets
)
mgr
.
max_batch_size
=
max_batch_size
mgr
.
use_dp
=
False
mgr
.
budget_graphs
=
{}
mgr
.
graph_hits
=
0
mgr
.
graph_misses
=
0
mgr
.
log_stats_interval
=
100
mgr
.
model
=
model
mgr
.
config
=
model
.
get_encoder_cudagraph_config
()
mgr
.
device
=
device
mgr
.
dtype
=
dtype
return
mgr
def
_make_pixel_values
(
grid_thw_list
:
list
[
list
[
int
]],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
"""Random pixel_values matching the total input patch count."""
n
=
_count_input_patches
(
grid_thw_list
)
return
torch
.
randn
(
n
,
_FLAT
,
device
=
device
,
dtype
=
dtype
)
def
_make_mm_kwargs
(
grid_thw_list
:
list
[
list
[
int
]],
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
dict
[
str
,
Any
]:
"""Create mm_kwargs for testing."""
return
{
"pixel_values"
:
_make_pixel_values
(
grid_thw_list
,
device
,
dtype
),
"image_grid_thw"
:
grid_thw_list
,
}
# ---------------------------------------------------------------------------
# GPU tests — capture, replay, fallback, counters, chunking
# ---------------------------------------------------------------------------
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
class
TestEncoderCudaGraphCaptureReplay
:
def
setup_method
(
self
):
self
.
device
=
torch
.
device
(
"cuda:0"
)
self
.
dtype
=
torch
.
float16
self
.
model
=
SimpleMockViTModel
().
to
(
self
.
device
).
half
()
self
.
mgr
=
_make_manager_for_gpu
(
self
.
model
,
_BUDGETS
,
_MAX_BATCH
,
self
.
device
,
self
.
dtype
)
self
.
mgr
.
capture
()
# --- capture ---
def
test_capture_creates_one_graph_per_budget
(
self
):
assert
len
(
self
.
mgr
.
budget_graphs
)
==
len
(
_BUDGETS
)
assert
set
(
self
.
mgr
.
budget_graphs
.
keys
())
==
set
(
_BUDGETS
)
# --- output shape ---
def
test_execute_returns_one_tensor_per_image
(
self
):
grid_thw
=
[[
1
,
4
,
4
],
[
1
,
4
,
4
]]
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
result
=
self
.
mgr
.
execute
(
mm_kwargs
)
assert
result
is
not
None
assert
len
(
result
)
==
2
def
test_execute_output_tokens_per_image
(
self
):
# [1,4,4] → 1*(4//2)*(4//2) = 4 tokens; [1,8,8] → 16 tokens
grid_thw
=
[[
1
,
4
,
4
],
[
1
,
8
,
8
]]
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
result
=
self
.
mgr
.
execute
(
mm_kwargs
)
assert
result
is
not
None
assert
result
[
0
].
shape
==
(
4
,
_HIDDEN
)
assert
result
[
1
].
shape
==
(
16
,
_HIDDEN
)
# --- budget fallback ---
def
test_eager_fallback_when_tokens_exceed_all_budgets
(
self
):
# [1,18,18] → 1*(18//2)*(18//2) = 81 tokens > max budget 64.
# Greedy packing handles the fallback internally: the oversized image
# gets an eager forward pass and is returned as part of the output list
# (execute() no longer returns None for individual image misses).
grid_thw
=
[[
1
,
18
,
18
]]
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
result
=
self
.
mgr
.
execute
(
mm_kwargs
)
assert
result
is
not
None
assert
len
(
result
)
==
1
# Eager output: SimpleMockViTModel produces n_out = 81 tokens
assert
result
[
0
].
shape
==
(
81
,
_HIDDEN
)
assert
self
.
mgr
.
graph_misses
==
1
# --- counters ---
def
test_hit_counter_increments_by_num_images
(
self
):
grid_thw
=
[[
1
,
4
,
4
],
[
1
,
4
,
4
]]
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
self
.
mgr
.
execute
(
mm_kwargs
)
assert
self
.
mgr
.
graph_hits
==
2
def
test_miss_counter_increments_by_num_images
(
self
):
grid_thw
=
[[
1
,
18
,
18
]]
# 81 tokens > 64
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
self
.
mgr
.
execute
(
mm_kwargs
)
assert
self
.
mgr
.
graph_misses
==
1
# --- chunking ---
def
test_chunking_when_images_exceed_max_batch
(
self
):
# 8 images > max_batch_size=4 → 2 chunks of 4
# each chunk: 4 * 4 = 16 tokens → fits budget 16
n_images
=
_MAX_BATCH
*
2
grid_thw
=
[[
1
,
4
,
4
]]
*
n_images
mm_kwargs
=
_make_mm_kwargs
(
grid_thw
,
self
.
device
,
self
.
dtype
)
result
=
self
.
mgr
.
execute
(
mm_kwargs
)
assert
result
is
not
None
assert
len
(
result
)
==
n_images
for
out
in
result
:
assert
out
.
shape
==
(
4
,
_HIDDEN
)
vllm/config/compilation.py
View file @
f85e479e
...
...
@@ -489,6 +489,28 @@ class CompilationConfig:
on selected platforms. Disabled by default until more models
are supported/tested to work."""
# Vision encoder CUDA graph
cudagraph_mm_encoder
:
bool
=
False
"""Enable CUDA graph capture for multimodal encoder (ViT).
When enabled, captures full encoder forward as CUDA graph
for each token budget level."""
encoder_cudagraph_token_budgets
:
list
[
int
]
=
field
(
default_factory
=
list
)
"""Token budget levels for encoder CUDA graph capture.
Each budget defines a fixed token capacity. At runtime, images are greedy-packed
into the smallest fitting budget and the corresponding CUDA graph is replayed.
If empty (default), auto-inferred from model architecture as power-of-2
levels from the model's estimated min budget to max budget.
User-provided values override auto-inference.
Example: [2048, 4096, 8192, 13824]"""
encoder_cudagraph_max_images_per_batch
:
int
=
0
"""Maximum number of images per batch for encoder CUDA graph capture.
Determines the fixed batch size used during graph capture.
If 0 (default), auto-inferred as max_budget // min_budget from the
model's budget range. User-provided positive value overrides
auto-inference."""
# Inductor capture
compile_sizes
:
list
[
int
|
str
]
|
None
=
None
"""Sizes to compile for inductor. In addition
...
...
@@ -906,6 +928,16 @@ class CompilationConfig:
f
"Invalid backend for piecewise compilation:
{
self
.
backend
}
"
)
# Validate encoder CUDA graph configuration
if
(
self
.
cudagraph_mm_encoder
and
self
.
encoder_cudagraph_max_images_per_batch
<
0
):
raise
ValueError
(
"encoder_cudagraph_max_images_per_batch must be "
"non-negative (0 = auto-infer)"
)
if
self
.
backend
==
""
:
self
.
backend
=
current_platform
.
get_compile_backend
()
...
...
vllm/model_executor/models/interfaces.py
View file @
f85e479e
...
...
@@ -13,6 +13,7 @@ from collections.abc import (
from
contextlib
import
ExitStack
,
contextmanager
,
nullcontext
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Literal
,
Protocol
,
...
...
@@ -46,6 +47,11 @@ if TYPE_CHECKING:
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
from
vllm.multimodal.registry
import
_ProcessorFactories
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.worker.gpu.mm.encoder_cudagraph_defs
import
(
EncoderCudaGraphCaptureInputs
,
EncoderCudaGraphConfig
,
EncoderCudaGraphReplayBuffers
,
)
else
:
VllmConfig
=
object
WeightsMapper
=
object
...
...
@@ -1494,3 +1500,138 @@ def supports_xdrope(
model
:
type
[
object
]
|
object
,
)
->
TypeIs
[
type
[
SupportsXDRoPE
]]
|
TypeIs
[
SupportsXDRoPE
]:
return
isinstance
(
model
,
SupportsXDRoPE
)
@
runtime_checkable
class
SupportsEncoderCudaGraph
(
Protocol
):
"""Interface for models whose vision encoder supports CUDA graph
capture/replay.
Models implement these methods to provide the
:class:`EncoderCudaGraphManager` with all model-specific logic
(input handling, metadata computation, forward pass) without the
manager needing to know model internals.
"""
supports_encoder_cudagraph
:
ClassVar
[
Literal
[
True
]]
=
True
def
get_encoder_cudagraph_config
(
self
)
->
"EncoderCudaGraphConfig"
:
...
def
get_encoder_cudagraph_budget_range
(
self
,
vllm_config
:
"VllmConfig"
,
)
->
tuple
[
int
,
int
]:
"""Return (min_token_budget, max_token_budget) for auto-inference.
- min_token_budget: estimated smallest possible encoder input
(e.g. 64 for a 224x224 image)
- max_token_budget: estimated largest budget worth capturing
(e.g. max_num_batched_tokens)
Used when ``encoder_cudagraph_token_budgets`` and/or
``encoder_cudagraph_max_images_per_batch`` are not explicitly
specified by the user.
"""
...
def
get_encoder_cudagraph_num_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
int
:
"""Return the number of items (e.g. images) in the batch."""
...
def
get_encoder_cudagraph_per_item_output_tokens
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
"""Return output token count for each item.
Used for greedy packing and DP load balancing.
"""
...
def
get_encoder_cudagraph_per_item_input_sizes
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
"""Return input size (e.g. patch count) for each item.
Used for input tensor slicing offsets.
"""
...
def
select_encoder_cudagraph_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
indices
:
list
[
int
],
)
->
dict
[
str
,
Any
]:
"""Select a subset of items and return mm_kwargs for the sub-batch.
Called by the manager during greedy packing and DP sharding to
extract inputs for a specific set of items (e.g. images at
indices [0, 3, 5]). The implementation is model-specific
because input formats differ:
- Qwen-family: slice concatenated pixel_values by cumulative
patch offsets, subset grid_thw by indices.
- Batched models (CLIP): index pixel_values along dim 0.
"""
...
def
prepare_encoder_cudagraph_capture_inputs
(
self
,
token_budget
:
int
,
max_batch_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
"EncoderCudaGraphCaptureInputs"
:
"""Create dummy inputs and buffers for CUDA graph capture."""
...
def
prepare_encoder_cudagraph_replay_buffers
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
max_batch_size
:
int
,
)
->
"EncoderCudaGraphReplayBuffers"
:
"""Compute buffer values from actual batch inputs for replay."""
...
def
encoder_cudagraph_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
buffers
:
dict
[
str
,
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""Run the encoder forward pass with precomputed buffers.
Used during both CUDA graph capture and replay.
"""
...
def
encoder_eager_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
torch
.
Tensor
:
"""Run the encoder forward pass without precomputed buffers.
Used as eager fallback when inputs exceed all budgets.
"""
...
@
overload
def
supports_encoder_cudagraph
(
model
:
type
[
object
],
)
->
TypeIs
[
type
[
SupportsEncoderCudaGraph
]]:
...
@
overload
def
supports_encoder_cudagraph
(
model
:
object
,
)
->
TypeIs
[
SupportsEncoderCudaGraph
]:
...
def
supports_encoder_cudagraph
(
model
:
type
[
object
]
|
object
,
)
->
TypeIs
[
type
[
SupportsEncoderCudaGraph
]]
|
TypeIs
[
SupportsEncoderCudaGraph
]:
return
isinstance
(
model
,
SupportsEncoderCudaGraph
)
vllm/model_executor/models/qwen3_vl.py
View file @
f85e479e
...
...
@@ -103,6 +103,7 @@ from .interfaces import (
MultiModalEmbeddings
,
SupportsEagle
,
SupportsEagle3
,
SupportsEncoderCudaGraph
,
SupportsLoRA
,
SupportsMRoPE
,
SupportsMultiModal
,
...
...
@@ -528,54 +529,120 @@ class Qwen3_VisionTransformer(nn.Module):
return
torch
.
cat
(
outputs
,
dim
=
0
)
def
prepare_encoder_metadata
(
self
,
grid_thw_list
:
list
[
list
[
int
]],
*
,
max_batch_size
:
int
|
None
=
None
,
max_seqlen_override
:
int
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
"""Compute encoder metadata from grid_thw_list.
Shared by the eager forward path, CUDA graph capture, and
CUDA graph replay to avoid duplicated implementation.
Args:
grid_thw_list: Grid configurations as list of [t, h, w].
max_batch_size: If set, pad cu_seqlens to this size
(needed for CUDA graph capture/replay).
max_seqlen_override: If set, use this value for max_seqlen
instead of computing from cu_seqlens (needed for CUDA
graph capture to cover worst-case replay scenarios).
device: Device to place tensors on. Defaults to self.device.
"""
if
device
is
None
:
device
=
self
.
device
metadata
:
dict
[
str
,
torch
.
Tensor
|
None
]
=
{}
# Positional embeddings
metadata
[
"pos_embeds"
]
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
rotary_cos
,
rotary_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
metadata
[
"rotary_pos_emb_cos"
]
=
rotary_cos
metadata
[
"rotary_pos_emb_sin"
]
=
rotary_sin
# cu_seqlens from grid_thw
grid_thw_np
=
np
.
array
(
grid_thw_list
,
dtype
=
np
.
int32
)
patches_per_frame
=
grid_thw_np
[:,
1
]
*
grid_thw_np
[:,
2
]
cu_seqlens
=
np
.
repeat
(
patches_per_frame
,
grid_thw_np
[:,
0
]).
cumsum
(
dtype
=
np
.
int32
)
cu_seqlens
=
np
.
concatenate
([
np
.
zeros
(
1
,
dtype
=
np
.
int32
),
cu_seqlens
])
# Pad cu_seqlens if max_batch_size specified
if
max_batch_size
is
not
None
:
num_seqs
=
len
(
cu_seqlens
)
-
1
if
num_seqs
<
max_batch_size
:
cu_seqlens
=
np
.
concatenate
(
[
cu_seqlens
,
np
.
full
(
max_batch_size
-
num_seqs
,
cu_seqlens
[
-
1
],
dtype
=
np
.
int32
,
),
]
)
# sequence_lengths (backend-specific)
metadata
[
"sequence_lengths"
]
=
MMEncoderAttention
.
maybe_compute_seq_lens
(
self
.
attn_backend
,
cu_seqlens
,
device
)
# max_seqlen
if
max_seqlen_override
is
not
None
:
max_seqlen_val
=
max_seqlen_override
else
:
max_seqlen_val
=
MMEncoderAttention
.
compute_max_seqlen
(
self
.
attn_backend
,
cu_seqlens
)
# Keep max_seqlen on CPU: attention wrappers call .item() on it,
# and having it on GPU would capture a wasteful D2H copy in CUDA
# graphs without changing behavior (the scalar is baked at capture).
metadata
[
"max_seqlen"
]
=
torch
.
tensor
(
max_seqlen_val
,
dtype
=
torch
.
int32
)
# Recompute cu_seqlens (backend-specific transformation)
metadata
[
"cu_seqlens"
]
=
MMEncoderAttention
.
maybe_recompute_cu_seqlens
(
self
.
attn_backend
,
cu_seqlens
,
self
.
hidden_size
,
self
.
tp_size
,
device
,
)
return
metadata
def
forward
(
self
,
x
:
torch
.
Tensor
,
grid_thw
:
torch
.
Tensor
|
list
[
list
[
int
]],
*
,
encoder_metadata
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
,
non_blocking
=
True
)
hidden_states
=
self
.
patch_embed
(
hidden_states
)
if
encoder_metadata
is
None
:
if
isinstance
(
grid_thw
,
list
):
grid_thw_list
=
grid_thw
grid_thw
=
np
.
array
(
grid_thw
,
dtype
=
np
.
int32
)
else
:
grid_thw_list
=
grid_thw
.
tolist
()
grid_thw
=
grid_thw
.
numpy
(
)
encoder_metadata
=
self
.
prepare_encoder_metadata
(
grid_thw
_list
)
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
pos_embeds
=
encoder_metadata
[
"pos_embeds"
]
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
cu_seqlens
=
np
.
repeat
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]).
cumsum
(
axis
=
0
,
dtype
=
np
.
int32
)
cu_seqlens
=
np
.
concatenate
([
np
.
zeros
(
1
,
dtype
=
np
.
int32
),
cu_seqlens
])
sequence_lengths
=
MMEncoderAttention
.
maybe_compute_seq_lens
(
self
.
attn_backend
,
cu_seqlens
,
self
.
device
)
max_seqlen
=
torch
.
tensor
(
MMEncoderAttention
.
compute_max_seqlen
(
self
.
attn_backend
,
cu_seqlens
),
dtype
=
torch
.
int32
,
)
cu_seqlens
=
MMEncoderAttention
.
maybe_recompute_cu_seqlens
(
self
.
attn_backend
,
cu_seqlens
,
self
.
hidden_size
,
self
.
tp_size
,
self
.
device
,
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
deepstack_feature_lists
=
[]
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
hidden_states
=
blk
(
hidden_states
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
sequence_lengths
=
sequence_lengths
,
cu_seqlens
=
encoder_metadata
[
"
cu_seqlens
"
]
,
rotary_pos_emb_cos
=
encoder_metadata
[
"
rotary_pos_emb_cos
"
]
,
rotary_pos_emb_sin
=
encoder_metadata
[
"
rotary_pos_emb_sin
"
]
,
max_seqlen
=
encoder_metadata
[
"
max_seqlen
"
]
,
sequence_lengths
=
encoder_metadata
.
get
(
"
sequence_lengths
"
)
,
)
if
layer_num
in
self
.
deepstack_visual_indexes
:
deepstack_merger_idx
=
self
.
deepstack_visual_indexes
.
index
(
layer_num
)
...
...
@@ -1358,6 +1425,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
class
Qwen3VLForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsEncoderCudaGraph
,
SupportsLoRA
,
SupportsPP
,
SupportsMRoPE
,
...
...
@@ -1507,6 +1575,178 @@ class Qwen3VLForConditionalGeneration(
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
zero_
()
# -- SupportsEncoderCudaGraph protocol methods --
def
get_encoder_cudagraph_config
(
self
):
from
vllm.v1.worker.gpu.mm.encoder_cudagraph_defs
import
(
EncoderCudaGraphConfig
,
)
return
EncoderCudaGraphConfig
(
modalities
=
[
"image"
],
input_key
=
"pixel_values"
,
buffer_keys
=
[
"pos_embeds"
,
"rotary_pos_emb_cos"
,
"rotary_pos_emb_sin"
,
"cu_seqlens"
,
"max_seqlen"
,
"sequence_lengths"
,
],
out_hidden_size
=
self
.
visual
.
out_hidden_size
,
)
def
get_encoder_cudagraph_budget_range
(
self
,
vllm_config
,
)
->
tuple
[
int
,
int
]:
# Min: estimated smallest possible encoder input.
# 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
min_budget
=
64
# Max: capped by max_num_batched_tokens
max_budget
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
return
(
min_budget
,
max_budget
)
def
get_encoder_cudagraph_num_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
int
:
return
len
(
mm_kwargs
[
"image_grid_thw"
])
def
get_encoder_cudagraph_per_item_output_tokens
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
m
=
self
.
visual
.
spatial_merge_size
return
[
t
*
(
h
//
m
)
*
(
w
//
m
)
for
t
,
h
,
w
in
mm_kwargs
[
"image_grid_thw"
]]
def
get_encoder_cudagraph_per_item_input_sizes
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
list
[
int
]:
return
[
t
*
h
*
w
for
t
,
h
,
w
in
mm_kwargs
[
"image_grid_thw"
]]
def
select_encoder_cudagraph_items
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
indices
:
list
[
int
],
)
->
dict
[
str
,
Any
]:
grid_thw
=
mm_kwargs
[
"image_grid_thw"
]
pixel_values
=
mm_kwargs
[
"pixel_values"
]
if
len
(
indices
)
==
0
:
return
{
"pixel_values"
:
pixel_values
[:
0
],
"image_grid_thw"
:
[],
}
# Compute cumulative patch offsets for slicing pixel_values
patches_per_item
=
[
t
*
h
*
w
for
t
,
h
,
w
in
grid_thw
]
cum_patches
=
[
0
]
for
p
in
patches_per_item
:
cum_patches
.
append
(
cum_patches
[
-
1
]
+
p
)
selected_pv
=
torch
.
cat
(
[
pixel_values
[
cum_patches
[
i
]
:
cum_patches
[
i
+
1
]]
for
i
in
indices
]
)
selected_grid
=
[
grid_thw
[
i
]
for
i
in
indices
]
return
{
"pixel_values"
:
selected_pv
,
"image_grid_thw"
:
selected_grid
,
}
def
prepare_encoder_cudagraph_capture_inputs
(
self
,
token_budget
:
int
,
max_batch_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
):
from
vllm.v1.worker.gpu.mm.encoder_cudagraph_defs
import
(
EncoderCudaGraphCaptureInputs
,
)
spatial_merge_size
=
self
.
visual
.
spatial_merge_size
per_image_output
=
token_budget
//
max_batch_size
# Synthetic rectangular grid: [1, merge, per_image_output * merge]
# produces exactly per_image_output tokens per image.
grid_config
=
[
[
1
,
spatial_merge_size
,
per_image_output
*
spatial_merge_size
]
for
_
in
range
(
max_batch_size
)
]
# Create dummy pixel_values
patch_embed
=
self
.
visual
.
patch_embed
in_channels
=
patch_embed
.
proj
.
in_channels
patch_size
=
patch_embed
.
patch_size
temporal_patch_size
=
patch_embed
.
temporal_patch_size
total_patches
=
sum
(
t
*
h
*
w
for
t
,
h
,
w
in
grid_config
)
flattened_patch_size
=
(
in_channels
*
temporal_patch_size
*
patch_size
*
patch_size
)
dummy_pixel_values
=
torch
.
randn
(
total_patches
,
flattened_patch_size
,
device
=
device
,
dtype
=
dtype
)
# Override max_seqlen with a safe upper bound for capture.
# max_seqlen.item() gets baked into the CUDA graph (not replayed),
# so the capture value must cover any replay scenario.
# Worst case: 1 image consuming the full budget ->
# seq_len = token_budget * spatial_merge_size^2.
buffers
=
self
.
visual
.
prepare_encoder_metadata
(
grid_config
,
max_batch_size
=
max_batch_size
,
max_seqlen_override
=
token_budget
*
(
spatial_merge_size
**
2
),
device
=
device
,
)
mm_kwargs
=
{
"pixel_values"
:
dummy_pixel_values
,
"image_grid_thw"
:
grid_config
,
}
return
EncoderCudaGraphCaptureInputs
(
mm_kwargs
=
mm_kwargs
,
buffers
=
buffers
,
)
def
prepare_encoder_cudagraph_replay_buffers
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
max_batch_size
:
int
,
):
from
vllm.v1.worker.gpu.mm.encoder_cudagraph_defs
import
(
EncoderCudaGraphReplayBuffers
,
)
grid_thw_list
=
mm_kwargs
[
"image_grid_thw"
]
buffers
=
self
.
visual
.
prepare_encoder_metadata
(
grid_thw_list
,
max_batch_size
=
max_batch_size
,
)
return
EncoderCudaGraphReplayBuffers
(
buffers
=
buffers
)
def
encoder_cudagraph_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
buffers
:
dict
[
str
,
torch
.
Tensor
],
)
->
torch
.
Tensor
:
pixel_values
=
mm_kwargs
[
"pixel_values"
]
grid_thw
=
mm_kwargs
[
"image_grid_thw"
]
return
self
.
visual
(
pixel_values
,
grid_thw
,
encoder_metadata
=
buffers
)
def
encoder_eager_forward
(
self
,
mm_kwargs
:
dict
[
str
,
Any
],
)
->
torch
.
Tensor
:
pixel_values
=
mm_kwargs
[
"pixel_values"
]
grid_thw
=
mm_kwargs
[
"image_grid_thw"
]
return
self
.
visual
(
pixel_values
,
grid_thw
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Qwen2_5_VLImageInputs
|
None
:
...
...
vllm/v1/worker/gpu/mm/encoder_cudagraph.py
0 → 100644
View file @
f85e479e
This diff is collapsed.
Click to expand it.
vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py
0 → 100644
View file @
f85e479e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Data transfer objects for encoder CUDA graph management."""
from
dataclasses
import
dataclass
from
typing
import
Any
import
torch
@
dataclass
class
EncoderCudaGraphConfig
:
"""Configuration for encoder CUDA graph management.
Provided by the model at init time via
``get_encoder_cudagraph_config()``. Values are fixed for the
lifetime of the manager.
"""
modalities
:
list
[
str
]
"""Supported modalities (e.g. ["image"])."""
input_key
:
str
"""Key in mm_kwargs for the input tensor (e.g. "pixel_values")."""
buffer_keys
:
list
[
str
]
"""Keys for the tensor buffers recorded into the CUDA graph.
Before replay the manager zeros then slice-copies new data
into these buffers."""
out_hidden_size
:
int
"""Output hidden dim of the vision encoder.
Used for DP gather buffer allocation."""
@
dataclass
class
EncoderCudaGraphCaptureInputs
:
"""Everything needed for one CUDA graph capture.
Returned by ``prepare_encoder_cudagraph_capture_inputs()``.
"""
mm_kwargs
:
dict
[
str
,
Any
]
"""Dummy forward inputs (model-specific keys).
For Qwen3-VL this contains pixel_values and grid_thw."""
buffers
:
dict
[
str
,
torch
.
Tensor
]
"""Precomputed tensor buffers that will be recorded into the
CUDA graph. The manager stores references to these exact
tensor objects and copies new data into them before each
``graph.replay()`` call (buffer identity invariant)."""
@
dataclass
class
EncoderCudaGraphReplayBuffers
:
"""New buffer values for graph replay, computed by the model from
actual batch inputs.
Returned by ``prepare_encoder_cudagraph_replay_buffers()``.
Keys match ``EncoderCudaGraphConfig.buffer_keys``.
"""
buffers
:
dict
[
str
,
torch
.
Tensor
|
None
]
"""Data to copy into the captured buffers before replay.
``None`` values leave the corresponding captured buffer
unchanged."""
vllm/v1/worker/gpu_model_runner.py
View file @
f85e479e
...
...
@@ -207,6 +207,7 @@ from .utils import (
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.gpu.mm.encoder_cudagraph
import
EncoderCudaGraphManager
logger
=
init_logger
(
__name__
)
...
...
@@ -499,6 +500,9 @@ class GPUModelRunner(
self
.
encoder_cache
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
late_interaction_runner
=
LateInteractionRunner
()
# Encoder CUDA graph manager (initialized after model load if enabled)
self
.
encoder_cudagraph_manager
:
EncoderCudaGraphManager
|
None
=
None
self
.
use_aux_hidden_state_outputs
=
False
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
...
...
@@ -2664,6 +2668,18 @@ class GPUModelRunner(
with
self
.
timed_encoder_operation
(
should_time
,
mm_lora_refs
,
current_item_idx
,
num_items
):
cudagraph_output
=
None
if
(
self
.
encoder_cudagraph_manager
is
not
None
and
self
.
encoder_cudagraph_manager
.
supports_modality
(
modality
)
):
cudagraph_output
=
self
.
encoder_cudagraph_manager
.
execute
(
mm_kwargs_batch
,
)
if
cudagraph_output
is
not
None
:
batch_outputs
=
cudagraph_output
else
:
batch_outputs
=
model
.
embed_multimodal
(
**
mm_kwargs_batch
)
sanity_check_mm_encoder_outputs
(
batch_outputs
,
expected_num_items
=
num_items
)
...
...
@@ -5715,6 +5731,33 @@ class GPUModelRunner(
)
return
0
# Initialize encoder CUDA graph manager if enabled.
# Use get_model() to unwrap CUDAGraphWrapper/UBatchWrapper,
# because @runtime_checkable Protocol isinstance() checks do not
# work through __getattr__ forwarding.
if
(
self
.
compilation_config
.
cudagraph_mm_encoder
and
self
.
supports_mm_inputs
and
self
.
encoder_cudagraph_manager
is
None
):
from
vllm.model_executor.models.interfaces
import
(
SupportsEncoderCudaGraph
,
supports_encoder_cudagraph
,
)
from
vllm.v1.worker.gpu.mm.encoder_cudagraph
import
(
EncoderCudaGraphManager
,
)
raw_model
=
self
.
get_model
()
if
supports_encoder_cudagraph
(
raw_model
):
self
.
encoder_cudagraph_manager
=
EncoderCudaGraphManager
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
,
dtype
=
self
.
dtype
,
model
=
cast
(
SupportsEncoderCudaGraph
,
raw_model
),
)
logger
.
info
(
"Initialized EncoderCudaGraphManager for vision encoder"
)
compilation_counter
.
num_gpu_runner_capture_triggers
+=
1
start_time
=
time
.
perf_counter
()
...
...
@@ -5738,6 +5781,10 @@ class GPUModelRunner(
)
torch
.
accelerator
.
synchronize
()
# Capture encoder CUDA graphs if enabled
if
self
.
encoder_cudagraph_manager
is
not
None
:
self
.
encoder_cudagraph_manager
.
capture
()
torch
.
accelerator
.
synchronize
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment