Unverified commit 3fd03f1e, authored by Lucas Kabela and committed by GitHub
Browse files

[BE] Rename `should_torch_compile_mm_vit` to `should_torch_compile_mm_encoder` (#36281)


Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent 10a5f4d5
...@@ -26,7 +26,7 @@ This feature is off by default, but can be enabled by setting `compile_mm_encode ...@@ -26,7 +26,7 @@ This feature is off by default, but can be enabled by setting `compile_mm_encode
To compile a multimodal component such as an encoder, we follow the same mechanism as the LLM text backbone, with a few additional scaffoldings: To compile a multimodal component such as an encoder, we follow the same mechanism as the LLM text backbone, with a few additional scaffoldings:
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_vit`. This will gate the compilation behind our 1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
`compile_mm_encoder` configuration `compile_mm_encoder` configuration
2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile 2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
......
...@@ -47,6 +47,11 @@ IGNORE_COMPILE_KEY = "_ignore_compile_vllm" ...@@ -47,6 +47,11 @@ IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=nn.Module) _T = TypeVar("_T", bound=nn.Module)
def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool:
    """Report whether the multimodal encoder should be torch-compiled.

    Intended to be handed to `@support_torch_compile` as its `enable_if`
    callable; the answer is read directly from the `compile_mm_encoder`
    field of the compilation config.
    """
    compilation_config = vllm_config.compilation_config
    return compilation_config.compile_mm_encoder
def ignore_torch_compile(cls: type[_T]) -> type[_T]: def ignore_torch_compile(cls: type[_T]) -> type[_T]:
""" """
A decorator to ignore support_torch_compile decorator A decorator to ignore support_torch_compile decorator
......
...@@ -10,7 +10,10 @@ from torch import nn ...@@ -10,7 +10,10 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.attention import MMEncoderAttention
...@@ -25,7 +28,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -25,7 +28,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .vision import ( from .vision import (
is_vit_use_data_parallel, is_vit_use_data_parallel,
resolve_visual_encoder_outputs, resolve_visual_encoder_outputs,
should_torch_compile_mm_vit,
) )
...@@ -269,7 +271,7 @@ class Siglip2MLP(nn.Module): ...@@ -269,7 +271,7 @@ class Siglip2MLP(nn.Module):
@support_torch_compile( @support_torch_compile(
dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0}, dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
enable_if=should_torch_compile_mm_vit, enable_if=should_torch_compile_mm_encoder,
) )
class Siglip2EncoderLayer(nn.Module): class Siglip2EncoderLayer(nn.Module):
def __init__( def __init__(
......
...@@ -31,7 +31,10 @@ from transformers.models.llama4.image_processing_llama4_fast import ( ...@@ -31,7 +31,10 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit, get_best_fit,
) )
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -49,7 +52,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -49,7 +52,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
...@@ -454,7 +456,7 @@ class Llama4UnfoldConvolution(nn.Module): ...@@ -454,7 +456,7 @@ class Llama4UnfoldConvolution(nn.Module):
@support_torch_compile( @support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_encoder
) )
class Llama4VisionModel(nn.Module): class Llama4VisionModel(nn.Module):
def __init__( def __init__(
......
...@@ -42,7 +42,10 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ...@@ -42,7 +42,10 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig, Qwen2_5_VLVisionConfig,
) )
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -65,7 +68,6 @@ from vllm.model_executor.layers.rotary_embedding.common import ( ...@@ -65,7 +68,6 @@ from vllm.model_executor.layers.rotary_embedding.common import (
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import ( from vllm.multimodal.evs import (
compute_mrope_for_media, compute_mrope_for_media,
...@@ -424,7 +426,7 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -424,7 +426,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"rotary_pos_emb_cos": 0, "rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0, "rotary_pos_emb_sin": 0,
}, },
enable_if=should_torch_compile_mm_vit, enable_if=should_torch_compile_mm_encoder,
) )
class Qwen2_5_VisionBlock(nn.Module): class Qwen2_5_VisionBlock(nn.Module):
def __init__( def __init__(
...@@ -483,7 +485,7 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -483,7 +485,7 @@ class Qwen2_5_VisionBlock(nn.Module):
dynamic_arg_dims={ dynamic_arg_dims={
"x": 0, "x": 0,
}, },
enable_if=should_torch_compile_mm_vit, enable_if=should_torch_compile_mm_encoder,
) )
class Qwen2_5_VisionPatchEmbed(nn.Module): class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__( def __init__(
...@@ -518,7 +520,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): ...@@ -518,7 +520,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
dynamic_arg_dims={ dynamic_arg_dims={
"x": 0, "x": 0,
}, },
enable_if=should_torch_compile_mm_vit, enable_if=should_torch_compile_mm_encoder,
) )
class Qwen2_5_VisionPatchMerger(nn.Module): class Qwen2_5_VisionPatchMerger(nn.Module):
def __init__( def __init__(
......
...@@ -143,11 +143,6 @@ def is_vit_use_data_parallel(): ...@@ -143,11 +143,6 @@ def is_vit_use_data_parallel():
return mm_encoder_tp_mode == "data" return mm_encoder_tp_mode == "data"
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    """Tell `@support_torch_compile` whether compilation is switched on.

    Meant to be supplied as that decorator's `enable_if` argument; it simply
    forwards the `compile_mm_encoder` setting found on the compilation config.
    """
    compile_settings = vllm_config.compilation_config
    enabled = compile_settings.compile_mm_encoder
    return enabled
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]
VisionFeatureSelectStrategy: TypeAlias = ( VisionFeatureSelectStrategy: TypeAlias = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment