Unverified commit fe57be78, authored by Shanshan Shen and committed by GitHub
Browse files

[MM][CG] Support `--enable-vit-cuda-graph` option for VLM examples (#40580)


Signed-off-by: shen-shanshan <467638484@qq.com>
parent 8317cedc
...@@ -2463,6 +2463,12 @@ MODELS_NEED_VIDEO_METADATA = [ ...@@ -2463,6 +2463,12 @@ MODELS_NEED_VIDEO_METADATA = [
] ]
# Model types for which `maybe_add_vit_cuda_graph_compilation_config` applies
# the ViT CUDA graph compilation config when `--enable-vit-cuda-graph` is
# passed. Any model type not listed here is left untouched by that helper.
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
"qwen3_vl",
"qwen3_vl_moe",
]
def get_multi_modal_input(args): def get_multi_modal_input(args):
""" """
return { return {
...@@ -2575,6 +2581,29 @@ def apply_image_repeat( ...@@ -2575,6 +2581,29 @@ def apply_image_repeat(
return inputs, inputs_with_empty_media return inputs, inputs_with_empty_media
def maybe_add_vit_cuda_graph_compilation_config(args, engine_args):
    """Attach the ViT CUDA graph compilation config to ``engine_args`` on request.

    Reads ``args.model_type``, ``args.modality`` and
    ``args.enable_vit_cuda_graph``. Nothing is changed unless the flag is
    truthy and the model type is listed in ``MODELS_SUPPORT_VIT_CUDA_GRAPH``.

    When both conditions hold, ``engine_args.compilation_config`` is replaced
    (not merged) with a dict that sets ``cudagraph_mm_encoder`` to ``True`` and
    ``encoder_cudagraph_max_vision_items_per_batch`` to 1 for the ``"image"``
    and ``"video"`` modalities, or 2 for ``"image+video"``.

    Args:
        args: Parsed command-line namespace providing the three attributes
            listed above.
        engine_args: Engine arguments object; mutated in place.

    Returns:
        The same ``engine_args`` object that was passed in.

    Raises:
        ValueError: If the feature applies but ``args.modality`` is not one of
            ``"image"``, ``"video"`` or ``"image+video"``.
    """
    model_type = args.model_type
    modality = args.modality
    requested = args.enable_vit_cuda_graph

    # Guard clause: flag not set, or this model type is not in the supported
    # list -> hand the engine args back unchanged.
    if not requested or model_type not in MODELS_SUPPORT_VIT_CUDA_GRAPH:
        return engine_args

    # NOTE(review): the 1 / 2 values presumably mirror how many vision items
    # each example prompt carries for the modality - confirm against callers.
    if modality in ("image", "video"):
        max_vision_items = 1
    elif modality == "image+video":
        max_vision_items = 2
    else:
        raise ValueError(
            f"modality={modality} is not supported for vit cuda graph."
        )

    # This assignment overwrites any compilation config already present.
    engine_args.compilation_config = {
        "cudagraph_mm_encoder": True,
        "encoder_cudagraph_max_vision_items_per_batch": max_vision_items,
    }
    return engine_args
@contextmanager @contextmanager
def time_counter(enable: bool): def time_counter(enable: bool):
if enable: if enable:
...@@ -2625,33 +2654,28 @@ def parse_args(): ...@@ -2625,33 +2654,28 @@ def parse_args():
default=0, default=0,
help="Set the seed when initializing `vllm.LLM`.", help="Set the seed when initializing `vllm.LLM`.",
) )
parser.add_argument( parser.add_argument(
"--image-repeat-prob", "--image-repeat-prob",
type=float, type=float,
default=None, default=None,
help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)", help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)",
) )
parser.add_argument( parser.add_argument(
"--disable-mm-processor-cache", "--disable-mm-processor-cache",
action="store_true", action="store_true",
help="If True, disables caching of multi-modal processor.", help="If True, disables caching of multi-modal processor.",
) )
parser.add_argument( parser.add_argument(
"--time-generate", "--time-generate",
action="store_true", action="store_true",
help="If True, then print the total generate() call time", help="If True, then print the total generate() call time",
) )
parser.add_argument( parser.add_argument(
"--use-different-prompt-per-request", "--use-different-prompt-per-request",
action="store_true", action="store_true",
help="If True, then use different prompt (with the same multi-modal " help="If True, then use different prompt (with the same multi-modal "
"data) for each request.", "data) for each request.",
) )
parser.add_argument( parser.add_argument(
"--verify-mm-cache-hit-with-uuids", "--verify-mm-cache-hit-with-uuids",
action="store_true", action="store_true",
...@@ -2665,6 +2689,11 @@ def parse_args(): ...@@ -2665,6 +2689,11 @@ def parse_args():
default=None, default=None,
help="Tensor parallel size to override the model's default setting. ", help="Tensor parallel size to override the model's default setting. ",
) )
parser.add_argument(
"--enable-vit-cuda-graph",
action="store_true",
help="If True, will enable vit cuda graph capture and replay for the model.",
)
return parser.parse_args() return parser.parse_args()
...@@ -2698,6 +2727,7 @@ def main(args): ...@@ -2698,6 +2727,7 @@ def main(args):
engine_args.mm_processor_cache_gb = mm_processor_cache_gb engine_args.mm_processor_cache_gb = mm_processor_cache_gb
if args.tensor_parallel_size is not None: if args.tensor_parallel_size is not None:
engine_args.tensor_parallel_size = args.tensor_parallel_size engine_args.tensor_parallel_size = args.tensor_parallel_size
engine_args = maybe_add_vit_cuda_graph_compilation_config(args, engine_args)
llm = LLM.from_engine_args(engine_args) llm = LLM.from_engine_args(engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`. # Don't want to check the flag multiple times, so just hijack `prompts`.
......
...@@ -1802,7 +1802,11 @@ class Qwen3VLForConditionalGeneration( ...@@ -1802,7 +1802,11 @@ class Qwen3VLForConditionalGeneration(
# spatial_merge_size=2 → 8x8 = 64 tokens # spatial_merge_size=2 → 8x8 = 64 tokens
min_budget = 64 min_budget = 64
# Max: capped by max_num_batched_tokens # Max: capped by max_num_batched_tokens
max_budget = vllm_config.scheduler_config.max_num_batched_tokens # TODO(shen-shanshan): the max_budget auto-infer needs to be optimized later.
max_budget = min(
vllm_config.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
)
return (min_budget, max_budget) return (min_budget, max_budget)
def _get_pixel_values_by_modality( def _get_pixel_values_by_modality(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment