Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea6d067a
Unverified
Commit
ea6d067a
authored
Jan 09, 2026
by
Lucas Kabela
Committed by
GitHub
Jan 09, 2026
Browse files
[Misc][LLaMa4] Compile LLaMa Vision Encoder (#30709)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
abd92242
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
85 additions
and
20 deletions
+85
-20
tests/compile/fullgraph/test_multimodal_compile.py
tests/compile/fullgraph/test_multimodal_compile.py
+37
-0
vllm/config/compilation.py
vllm/config/compilation.py
+3
-2
vllm/model_executor/layers/attention/mm_encoder_attention.py
vllm/model_executor/layers/attention/mm_encoder_attention.py
+3
-3
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
...el_executor/layers/rotary_embedding/llama4_vision_rope.py
+5
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-1
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+29
-9
vllm/v1/attention/ops/vit_attn_wrappers.py
vllm/v1/attention/ops/vit_attn_wrappers.py
+3
-3
No files found.
tests/compile/fullgraph/test_multimodal_compile.py
View file @
ea6d067a
...
@@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
...
@@ -71,3 +71,40 @@ def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
)
as
_
,
)
as
_
,
):
):
pass
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
# Also requires CUDA and 8 GPUs
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI resource constraints"
)
def
test_mllama4_vit_compilation
(
vllm_runner
,
monkeypatch
):
"""Test that Mllama4 vision submodules are compiled.
This test verifies that the 2 vision submodules (Llama4VisionEncoder,
Llama4VisionPixelShuffleMLP) are properly tagged
for compilation by checking that num_models_seen increases to 3.
However, since we are using TP=8, compilation_counter will not
work properly, so for now we just check that the run succeeds.
"""
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
with
(
monkeypatch
.
context
(),
# TODO: Requiring TP=8 interferes with the compilation counter.
# We should fix this in the future; for now, keep the test to make
# sure that compilation runs (no crash) with the Llama vision encoder
compilation_counter
.
expect
(
num_models_seen
=
0
),
vllm_runner
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
max_model_len
=
512
,
gpu_memory_utilization
=
0.8
,
tensor_parallel_size
=
8
,
compilation_config
=
{
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
"compile_mm_encoder"
:
True
,
},
),
):
pass
vllm/config/compilation.py
View file @
ea6d067a
...
@@ -430,8 +430,9 @@ class CompilationConfig:
...
@@ -430,8 +430,9 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder
:
bool
=
False
compile_mm_encoder
:
bool
=
False
"""Whether or not to compile the multimodal encoder.
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
Disabled by default until more models are supported/tested to work."""
on selected platforms. Disabled by default until more models
are supported/tested to work."""
# Inductor capture
# Inductor capture
compile_sizes
:
list
[
int
|
str
]
|
None
=
None
compile_sizes
:
list
[
int
|
str
]
|
None
=
None
...
...
vllm/model_executor/layers/attention/mm_encoder_attention.py
View file @
ea6d067a
...
@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp):
...
@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp):
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
scale
=
self
.
scale
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
batch_size
=
bsz
,
batch_size
=
bsz
,
is_rocm_aiter
=
(
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
),
is_rocm_aiter
=
(
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
),
fa_version
=
self
.
_fa_version
,
fa_version
=
self
.
_fa_version
,
scale
=
self
.
scale
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
)
if
is_reshaped
:
if
is_reshaped
:
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
...
...
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
View file @
ea6d067a
...
@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
...
@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
assert
key
is
not
None
assert
key
is
not
None
# self.cos_sin_cache here is complex tensor so we cannot cast into
# self.cos_sin_cache here is complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype
# query's dtype directly with self._match_cos_sin_cache_dtype
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
# NOTE: by not storing cos_sin_cache back on self, we avoid a
# memory buffer update, which is costly at runtime
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
key_
=
torch
.
view_as_complex
(
key
.
float
().
reshape
(
*
key
.
shape
[:
-
1
],
-
1
,
2
))
key_
=
torch
.
view_as_complex
(
key
.
float
().
reshape
(
*
key
.
shape
[:
-
1
],
-
1
,
2
))
broadcast_shape
=
[
broadcast_shape
=
[
d
if
i
==
1
or
i
==
(
query_
.
ndim
-
1
)
else
1
d
if
i
==
1
or
i
==
(
query_
.
ndim
-
1
)
else
1
for
i
,
d
in
enumerate
(
query_
.
shape
)
for
i
,
d
in
enumerate
(
query_
.
shape
)
]
]
freqs_ci
=
self
.
cos_sin_cache
.
view
(
*
broadcast_shape
)
freqs_ci
=
cos_sin_cache
.
view
(
*
broadcast_shape
)
query_out
=
torch
.
view_as_real
(
query_
*
freqs_ci
).
flatten
(
3
)
query_out
=
torch
.
view_as_real
(
query_
*
freqs_ci
).
flatten
(
3
)
key_out
=
torch
.
view_as_real
(
key_
*
freqs_ci
).
flatten
(
3
)
key_out
=
torch
.
view_as_real
(
key_
*
freqs_ci
).
flatten
(
3
)
return
query_out
.
type_as
(
query
),
key_out
.
type_as
(
key
)
return
query_out
.
type_as
(
query
),
key_out
.
type_as
(
key
)
...
...
vllm/model_executor/models/llama.py
View file @
ea6d067a
...
@@ -369,7 +369,11 @@ def llama_model_invariants(
...
@@ -369,7 +369,11 @@ def llama_model_invariants(
torch
.
_check
(
positions
.
size
()[
0
]
==
input_ids
.
size
()[
0
])
torch
.
_check
(
positions
.
size
()[
0
]
==
input_ids
.
size
()[
0
])
@
support_torch_compile
(
shape_invariants
=
llama_model_invariants
)
@
support_torch_compile
(
# TODO[#32068]: Investigate recompilation
# mark_unbacked_dims={"input_ids": 0},
shape_invariants
=
llama_model_invariants
)
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/models/mllama4.py
View file @
ea6d067a
...
@@ -31,9 +31,11 @@ from transformers.models.llama4.image_processing_llama4_fast import (
...
@@ -31,9 +31,11 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit
,
get_best_fit
,
)
)
from
vllm.config
import
VllmConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
MMEncoderAttention
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
MMEncoderAttention
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
...
@@ -47,6 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -47,6 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.model_loader.utils
import
initialize_model
from
vllm.model_executor.model_loader.utils
import
initialize_model
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.vision
import
should_torch_compile_mm_vit
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalDataDict
,
...
@@ -456,6 +459,9 @@ class Llama4UnfoldConvolution(nn.Module):
...
@@ -456,6 +459,9 @@ class Llama4UnfoldConvolution(nn.Module):
return
hidden_states
return
hidden_states
@
support_torch_compile
(
dynamic_arg_dims
=
{
"images_flattened"
:
0
},
enable_if
=
should_torch_compile_mm_vit
)
class
Llama4VisionModel
(
nn
.
Module
):
class
Llama4VisionModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -497,6 +503,7 @@ class Llama4VisionModel(nn.Module):
...
@@ -497,6 +503,7 @@ class Llama4VisionModel(nn.Module):
prefix
=
f
"
{
prefix
}
.model"
,
prefix
=
f
"
{
prefix
}
.model"
,
use_data_parallel
=
use_data_parallel
,
use_data_parallel
=
use_data_parallel
,
)
)
self
.
vision_adapter
=
Llama4VisionPixelShuffleMLP
(
self
.
vision_adapter
=
Llama4VisionPixelShuffleMLP
(
config
,
config
,
quant_config
,
quant_config
,
...
@@ -762,18 +769,28 @@ class Llama4ForConditionalGeneration(
...
@@ -762,18 +769,28 @@ class Llama4ForConditionalGeneration(
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
self
.
vision_model
=
Llama4VisionModel
(
from
vllm.compilation.backends
import
set_model_tag
config
.
vision_config
,
None
,
with
(
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
set_current_vllm_config
(
vllm_config
),
use_data_parallel
=
self
.
use_data_parallel
,
set_model_tag
(
"Llama4VisionModel"
,
is_encoder
=
True
),
)
):
self
.
vision_model
=
Llama4VisionModel
(
config
=
config
.
vision_config
,
quant_config
=
None
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
self
.
multi_modal_projector
=
Llama4MultiModalProjector
(
self
.
multi_modal_projector
=
Llama4MultiModalProjector
(
self
.
config
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
)
config
=
self
.
config
,
quant_config
=
None
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
)
else
:
else
:
self
.
vision_model
=
None
self
.
vision_model
=
None
...
@@ -883,7 +900,10 @@ class Llama4ForConditionalGeneration(
...
@@ -883,7 +900,10 @@ class Llama4ForConditionalGeneration(
if
image_input
is
None
:
if
image_input
is
None
:
return
[]
return
[]
return
self
.
_process_image_input
(
image_input
)
with
(
set_forward_context
(
None
,
self
.
vllm_config
),
):
return
self
.
_process_image_input
(
image_input
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/v1/attention/ops/vit_attn_wrappers.py
View file @
ea6d067a
...
@@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake(
...
@@ -72,9 +72,9 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size
:
int
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
fa_version
:
int
|
None
,
scale
:
float
|
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
return
torch
.
empty_like
(
q
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment