Unverified Commit ea6d067a authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Misc][LLaMa4] Compile LLaMa Vision Encoder (#30709)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent abd92242
...@@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch): ...@@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
) as _, ) as _,
): ):
pass pass
# The forked marker is needed to work around
# https://github.com/vllm-project/vllm/issues/21073.
# Running this test also requires CUDA and 8 GPUs (tensor_parallel_size=8 below).
@pytest.mark.forked
@pytest.mark.skip(reason="Skipping due to CI resource constraints")
def test_mllama4_vit_compilation(vllm_runner, monkeypatch):
    """Smoke-test that Mllama4 starts up with vision-encoder compilation enabled.

    The intent is to verify that the two vision submodules
    (Llama4VisionEncoder, Llama4VisionPixelShuffleMLP) are tagged for
    compilation, which would show up as ``num_models_seen`` increasing to 3.
    However, the model is run with TP=8 and the compilation counter does not
    work properly in that setup, so for now the test expects
    ``num_models_seen=0`` and only checks that the run succeeds.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        # NOTE(review): the object yielded by this context is not bound; the
        # setenv call above goes through the outer `monkeypatch` fixture.
        monkeypatch.context(),
        # TODO: Because TP=8 is required, the compilation counter is not
        # reliable here. This should be fixed in the future; for now the
        # expectation is left at 0 so the test still verifies that compilation
        # runs (no crash) with the Llama vision encoder.
        compilation_counter.expect(num_models_seen=0),
        vllm_runner(
            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            max_model_len=512,
            gpu_memory_utilization=0.8,
            tensor_parallel_size=8,
            compilation_config={
                "mode": CompilationMode.VLLM_COMPILE,
                # Opt in to compiling the multimodal (vision) encoder.
                "compile_mm_encoder": True,
            },
        ),
    ):
        # Entering and leaving the context managers is the whole test.
        pass
...@@ -430,8 +430,9 @@ class CompilationConfig: ...@@ -430,8 +430,9 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs).""" If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder. """Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` on selected platforms. Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
Disabled by default until more models are supported/tested to work.""" on selected platforms. Disabled by default until more models
are supported/tested to work."""
# Inductor capture # Inductor capture
compile_sizes: list[int | str] | None = None compile_sizes: list[int | str] | None = None
......
...@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp): ...@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp):
q=query, q=query,
k=key, k=key,
v=value, v=value,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz, batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version, fa_version=self._fa_version,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
) )
if is_reshaped: if is_reshaped:
output = output.reshape(bsz, q_len, -1) output = output.reshape(bsz, q_len, -1)
......
...@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase): ...@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
assert key is not None assert key is not None
# self.cos_sin_cache here is complex tensor so we cannot cast into # self.cos_sin_cache here is complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype # query's dtype directly with self._match_cos_sin_cache_dtype
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
# NOTE: by keeping cos_sin_cache in a local variable instead of storing it
# back on self, we avoid a memory buffer update, which is costly at runtime
cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
broadcast_shape = [ broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1 d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape) for i, d in enumerate(query_.shape)
] ]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape) freqs_ci = cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key) return query_out.type_as(query), key_out.type_as(key)
......
...@@ -369,7 +369,11 @@ def llama_model_invariants( ...@@ -369,7 +369,11 @@ def llama_model_invariants(
torch._check(positions.size()[0] == input_ids.size()[0]) torch._check(positions.size()[0] == input_ids.size()[0])
@support_torch_compile(shape_invariants=llama_model_invariants) @support_torch_compile(
# TODO[#32068]: Investigate recompilation
# mark_unbacked_dims={"input_ids": 0},
shape_invariants=llama_model_invariants
)
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__( def __init__(
self, self,
......
...@@ -31,9 +31,11 @@ from transformers.models.llama4.image_processing_llama4_fast import ( ...@@ -31,9 +31,11 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit, get_best_fit,
) )
from vllm.config import VllmConfig from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -47,6 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -47,6 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
...@@ -456,6 +459,9 @@ class Llama4UnfoldConvolution(nn.Module): ...@@ -456,6 +459,9 @@ class Llama4UnfoldConvolution(nn.Module):
return hidden_states return hidden_states
@support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
)
class Llama4VisionModel(nn.Module): class Llama4VisionModel(nn.Module):
def __init__( def __init__(
self, self,
...@@ -497,6 +503,7 @@ class Llama4VisionModel(nn.Module): ...@@ -497,6 +503,7 @@ class Llama4VisionModel(nn.Module):
prefix=f"{prefix}.model", prefix=f"{prefix}.model",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
) )
self.vision_adapter = Llama4VisionPixelShuffleMLP( self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, config,
quant_config, quant_config,
...@@ -762,18 +769,28 @@ class Llama4ForConditionalGeneration( ...@@ -762,18 +769,28 @@ class Llama4ForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.vllm_config = vllm_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"): if multimodal_config.get_limit_per_prompt("image"):
from vllm.compilation.backends import set_model_tag
with (
set_current_vllm_config(vllm_config),
set_model_tag("Llama4VisionModel", is_encoder=True),
):
self.vision_model = Llama4VisionModel( self.vision_model = Llama4VisionModel(
config.vision_config, config=config.vision_config,
None, quant_config=None,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel, use_data_parallel=self.use_data_parallel,
) )
self.multi_modal_projector = Llama4MultiModalProjector( self.multi_modal_projector = Llama4MultiModalProjector(
self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector") config=self.config,
quant_config=None,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
else: else:
self.vision_model = None self.vision_model = None
...@@ -883,6 +900,9 @@ class Llama4ForConditionalGeneration( ...@@ -883,6 +900,9 @@ class Llama4ForConditionalGeneration(
if image_input is None: if image_input is None:
return [] return []
with (
set_forward_context(None, self.vllm_config),
):
return self._process_image_input(image_input) return self._process_image_input(image_input)
def forward( def forward(
......
...@@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None, fa_version: int | None,
scale: float | None, scale: float | None = None,
cu_seqlens: torch.Tensor | None, cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None, max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(q) return torch.empty_like(q)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment