"tests/vscode:/vscode.git/clone" did not exist on "d9fc8cd9da4a69cb4171efb7cb5a46308680c83c"
Unverified Commit 3fd03f1e authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[BE] Rename `should_torch_compile_mm_vit` to `should_torch_compile_mm_encoder` (#36281)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent 10a5f4d5
......@@ -26,7 +26,7 @@ This feature is off by default, but can be enabled by setting `compile_mm_encode
To compile a multimodal component such as an encoder, we follow the same mechanism as the LLM text backbone, with a few additional scaffoldings:
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_vit`. This will gate the compilation behind our
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
`compile_mm_encoder` configuration
2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
......
......@@ -47,6 +47,11 @@ IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=nn.Module)
def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool:
"""Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
return vllm_config.compilation_config.compile_mm_encoder
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
"""
A decorator to ignore support_torch_compile decorator
......
......@@ -10,7 +10,10 @@ from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
......@@ -25,7 +28,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import (
is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
should_torch_compile_mm_vit,
)
......@@ -269,7 +271,7 @@ class Siglip2MLP(nn.Module):
@support_torch_compile(
dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Siglip2EncoderLayer(nn.Module):
def __init__(
......
......@@ -31,7 +31,10 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -49,7 +52,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
......@@ -454,7 +456,7 @@ class Llama4UnfoldConvolution(nn.Module):
@support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_encoder
)
class Llama4VisionModel(nn.Module):
def __init__(
......
......@@ -42,7 +42,10 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
......@@ -65,7 +68,6 @@ from vllm.model_executor.layers.rotary_embedding.common import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
......@@ -424,7 +426,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
......@@ -483,7 +485,7 @@ class Qwen2_5_VisionBlock(nn.Module):
dynamic_arg_dims={
"x": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__(
......@@ -518,7 +520,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
dynamic_arg_dims={
"x": 0,
},
enable_if=should_torch_compile_mm_vit,
enable_if=should_torch_compile_mm_encoder,
)
class Qwen2_5_VisionPatchMerger(nn.Module):
def __init__(
......
......@@ -143,11 +143,6 @@ def is_vit_use_data_parallel():
return mm_encoder_tp_mode == "data"
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
"""Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
return vllm_config.compilation_config.compile_mm_encoder
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]
VisionFeatureSelectStrategy: TypeAlias = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment