Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c5752323
Unverified
Commit
c5752323
authored
Apr 05, 2025
by
Lu Fang
Committed by
GitHub
Apr 05, 2025
Browse files
[Model] Support Llama4 in vLLM (#16104)
parent
63375f0c
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1864 additions
and
91 deletions
+1864
-91
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+24
-14
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+15
-12
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+6
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+5
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+18
-15
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+17
-13
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+68
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+17
-5
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+530
-0
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+886
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+236
-14
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+37
-18
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
c5752323
...
@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_weight
,
x
,
layer
.
w2_weight
,
layer
.
w13_weight
,
topk_weights
=
topk_weights
,
layer
.
w2_weight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
activation
=
activation
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
use_fp8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
global_num_experts
=
global_num_experts
,
w2_scale
=
layer
.
w2_weight_scale
,
expert_map
=
expert_map
,
a1_scale
=
layer
.
w13_input_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
a2_scale
=
layer
.
w2_input_scale
)
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsW8A8Fp8MoECutlassMethod
(
CompressedTensorsMoEMethod
):
class
CompressedTensorsW8A8Fp8MoECutlassMethod
(
CompressedTensorsMoEMethod
):
...
@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
...
@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
...
@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for "
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
c5752323
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_weight
,
x
,
layer
.
w2_weight
,
layer
.
w13_weight
,
topk_weights
=
topk_weights
,
layer
.
w2_weight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
activation
=
activation
,
inplace
=
True
,
use_int8_w8a16
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
use_int8_w8a16
=
True
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
w1_scale
=
layer
.
w13_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w2_scale
=
layer
.
w2_scale
)
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
@
staticmethod
@
staticmethod
def
quantizing_weight_loader
(
layer
,
weight_loader
):
def
quantizing_weight_loader
(
layer
,
weight_loader
):
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
c5752323
...
@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation
=
activation
,
activation
=
activation
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
c5752323
...
@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
...
@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
):
):
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused GGUF MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
c5752323
...
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
is
not
None
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused Marlin MoE method."
)
# The input must currently be float16
# The input must currently be float16
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
c5752323
...
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits
=
self
.
quant_config
.
weight_bits
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
has_zp
=
self
.
quant_config
.
has_zp
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_qweight
,
x
,
layer
.
w2_qweight
,
layer
.
w13_qweight
,
topk_weights
=
topk_weights
,
layer
.
w2_qweight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
use_int4_w4a16
=
weight_bits
==
4
,
inplace
=
True
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int4_w4a16
=
weight_bits
==
4
,
global_num_experts
=
global_num_experts
,
use_int8_w8a16
=
weight_bits
==
8
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
w1_scale
=
layer
.
w13_scales
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w2_scale
=
layer
.
w2_scales
,
expert_map
=
expert_map
,
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
w1_scale
=
layer
.
w13_scales
,
w2_zp
=
layer
.
w2_qzeros
if
has_zp
else
None
,
w2_scale
=
layer
.
w2_scales
,
block_shape
=
[
0
,
layer
.
group_size
])
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
w2_zp
=
layer
.
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
layer
.
group_size
])
@
staticmethod
@
staticmethod
def
get_weight_loader
(
layer
,
weight_loader
):
def
get_weight_loader
(
layer
,
weight_loader
):
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
c5752323
...
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
layer
.
w13_weight
,
x
,
layer
.
w2_weight
,
layer
.
w13_weight
,
topk_weights
=
topk_weights
,
layer
.
w2_weight
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inplace
=
True
,
topk_ids
=
topk_ids
,
use_fp8_w8a8
=
True
,
inplace
=
True
,
global_num_experts
=
global_num_experts
,
use_fp8_w8a8
=
True
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
w1_scale
=
layer
.
w13_weight_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w2_scale
=
layer
.
w2_weight_scale
,
expert_map
=
expert_map
,
a1_scale
=
layer
.
w13_input_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
a2_scale
=
layer
.
w2_input_scale
)
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
vllm/model_executor/layers/rotary_embedding.py
View file @
c5752323
...
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
...
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return
new_freqs
return
new_freqs
class
Llama4VisionRotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
inv_freqs
=
inv_freqs
[:(
self
.
rotary_dim
//
2
)]
return
inv_freqs
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches
=
self
.
max_position_embeddings
img_idx
=
torch
.
arange
(
num_patches
,
dtype
=
torch
.
int32
)
\
.
reshape
(
num_patches
,
1
)
img_idx
=
torch
.
cat
([
img_idx
,
img_idx
[:
1
]],
dim
=
0
)
img_idx
[
-
1
,
-
1
]
=
-
2
# set to ID_CLS_TOKEN
num_patches_single_dim
=
int
(
math
.
sqrt
(
num_patches
))
frequencies_x
=
img_idx
%
num_patches_single_dim
frequencies_y
=
img_idx
//
num_patches_single_dim
freqs_x
=
((
frequencies_x
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs_y
=
((
frequencies_y
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs
=
torch
.
cat
([
freqs_x
,
freqs_y
],
dim
=-
1
).
float
().
contiguous
()[...,
::
2
]
freqs
=
freqs
.
masked_fill
(
img_idx
.
reshape
(
-
1
,
1
,
1
)
<
0
,
0
)
cache
=
torch
.
view_as_complex
(
torch
.
stack
([
torch
.
cos
(
freqs
),
torch
.
sin
(
freqs
)],
dim
=-
1
))
return
cache
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
key_
=
torch
.
view_as_complex
(
key
.
float
().
reshape
(
*
key
.
shape
[:
-
1
],
-
1
,
2
))
broadcast_shape
=
[
d
if
i
==
1
or
i
==
(
query_
.
ndim
-
1
)
else
1
for
i
,
d
in
enumerate
(
query_
.
shape
)
]
freqs_ci
=
self
.
cos_sin_cache
.
view
(
*
broadcast_shape
)
query_out
=
torch
.
view_as_real
(
query_
*
freqs_ci
).
flatten
(
3
)
key_out
=
torch
.
view_as_real
(
key_
*
freqs_ci
).
flatten
(
3
)
return
query_out
.
type_as
(
query
),
key_out
.
type_as
(
key
)
class
MRotaryEmbedding
(
RotaryEmbedding
):
class
MRotaryEmbedding
(
RotaryEmbedding
):
"""Rotary Embedding with Multimodal Sections."""
"""Rotary Embedding with Multimodal Sections."""
...
@@ -1130,6 +1194,10 @@ def get_rope(
...
@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor
,
low_freq_factor
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
high_freq_factor
,
original_max_position
)
original_max_position
)
elif
scaling_type
==
"mllama4"
:
rotary_emb
=
Llama4VisionRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"default"
:
elif
scaling_type
==
"default"
:
if
"mrope_section"
in
rope_scaling
:
if
"mrope_section"
in
rope_scaling
:
rotary_emb
=
MRotaryEmbedding
(
rotary_emb
=
MRotaryEmbedding
(
...
...
vllm/model_executor/models/llama.py
View file @
c5752323
...
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
...
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
reduce_results
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
...
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
...
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
...
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm"
:
"post_attention_layernorm"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
"norm"
:
"model.norm"
,
}
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
...
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
),
layer_type
=
layer_type
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
_init_model
(
self
,
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
layer_type
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/llama4.py
0 → 100644
View file @
c5752323
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
from
torch
import
nn
from
transformers
import
Llama4TextConfig
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.llama
import
LlamaDecoderLayer
,
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
)
class
Llama4MoE
(
nn
.
Module
):
@
staticmethod
def
custom_routing_function
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
router_scores
,
router_indices
=
torch
.
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores
=
torch
.
sigmoid
(
router_scores
.
float
()).
to
(
hidden_states
.
dtype
)
return
(
router_scores
,
router_indices
.
to
(
torch
.
int32
))
def
__init__
(
self
,
config
:
Llama4TextConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
top_k
=
config
.
num_experts_per_tok
intermediate_size_moe
=
config
.
intermediate_size
self
.
router
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_local_experts
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.router"
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
custom_routing_function
=
Llama4MoE
.
custom_routing_function
,
intermediate_size
=
intermediate_size_moe
,
apply_router_weight_on_input
=
True
,
reduce_results
=
False
,
renormalize
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
)
self
.
shared_expert
=
LlamaMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size_moe
,
hidden_act
=
"silu"
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
reduce_results
=
False
,
# We need to do scatter before reduce
)
def
forward
(
self
,
hidden_states
):
router_logits
,
_
=
self
.
router
(
hidden_states
)
shared_out
=
self
.
shared_expert
(
hidden_states
)
routed_out
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
experts_out
=
routed_out
+
shared_out
if
self
.
tp_size
>
1
:
experts_out
=
tensor_model_parallel_all_reduce
(
experts_out
)
return
experts_out
class
Llama4Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4TextConfig
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias_o_proj
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
hidden_size
=
hidden_size
self
.
no_rope_layers
=
config
.
no_rope_layers
self
.
nope
=
self
.
no_rope_layers
[
self
.
layer_idx
]
==
0
self
.
use_qk_norm
=
config
.
use_qk_norm
and
not
self
.
nope
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
config
.
head_dim
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
# TODO: attn_temperature_tuning should be a bool in huggingface
self
.
attn_temperature_tuning
=
self
.
nope
and
\
config
.
attn_temperature_tuning
>
0
self
.
floor_scale
=
getattr
(
config
,
"floor_scale"
,
8192.0
)
self
.
attn_scale
=
getattr
(
config
,
"attn_scale"
,
0.1
)
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
n_rep
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
q_norm
=
RMSNorm
(
hidden_size
=
self
.
q_size
,
eps
=
config
.
rms_norm_eps
,
has_weight
=
False
,
dtype
=
torch
.
float32
,
)
if
self
.
use_qk_norm
else
None
self
.
k_norm
=
RMSNorm
(
hidden_size
=
self
.
kv_size
,
eps
=
config
.
rms_norm_eps
,
has_weight
=
False
,
dtype
=
torch
.
float32
,
)
if
self
.
use_qk_norm
else
None
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
hidden_size
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_num_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
bias
=
bias_o_proj
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
is_neox_style
=
True
is_gguf
=
quant_config
and
quant_config
.
get_name
()
==
"gguf"
if
is_gguf
and
config
.
model_type
==
"llama"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
int
(
rope_theta
),
rope_scaling
=
rope_scaling
if
rope_scaling
!=
"default"
else
None
,
is_neox_style
=
is_neox_style
,
)
if
not
self
.
nope
else
None
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
per_layer_sliding_window
=
None
,
use_irope
=
not
self
.
nope
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
def
_get_attn_scale
(
self
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
floor
=
torch
.
floor
((
positions
+
1.0
)
/
self
.
floor_scale
)
attn_scale
=
torch
.
log
(
floor
+
1.0
)
*
self
.
attn_scale
+
1.0
return
attn_scale
.
unsqueeze
(
-
1
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
rotary_emb
is
not
None
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
self
.
q_norm
is
not
None
:
q
=
self
.
q_norm
(
q
.
float
()).
to
(
q
.
dtype
)
if
self
.
k_norm
is
not
None
:
k
=
self
.
k_norm
(
k
.
float
()).
to
(
k
.
dtype
)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
# is customized to not affect short context
# while working at very long context
# https://arxiv.org/abs/2501.19399
#
# We should apply temperature tuning between (after) rotary / QK norm
# and (before) attention.
if
self
.
attn_temperature_tuning
and
self
.
nope
:
attn_scale
=
self
.
_get_attn_scale
(
positions
)
q
=
(
q
*
attn_scale
).
to
(
q
.
dtype
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Llama4DecoderLayer
(
LlamaDecoderLayer
):
def
__init__
(
self
,
config
:
Llama4TextConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
self
.
layer_idx
=
extract_layer_index
(
prefix
)
nn
.
Module
.
__init__
(
self
)
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
config
.
rope_theta
rope_scaling
=
config
.
rope_scaling
max_position_embeddings
=
config
.
max_position_embeddings
self
.
self_attn
=
Llama4Attention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
quant_config
=
quant_config
,
bias
=
False
,
bias_o_proj
=
False
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
is_moe_layer
=
(
self
.
layer_idx
+
1
)
%
config
.
interleave_moe_layer_step
==
0
if
is_moe_layer
:
self
.
feed_forward
=
Llama4MoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
,
)
else
:
self
.
feed_forward
=
LlamaMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size_mlp
,
hidden_act
=
"silu"
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.feed_forward"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
feed_forward
(
hidden_states
)
return
hidden_states
,
residual
@
support_torch_compile
class
Llama4Model
(
LlamaModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
self
.
num_experts
=
vllm_config
.
model_config
.
hf_config
.
num_local_experts
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
layer_type
)
def
load_moe_expert_weights
(
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
,
params_dict
:
Dict
[
str
,
nn
.
Parameter
],
loaded_params
:
Set
[
str
],
expert_params_mapping
:
List
[
Tuple
[
str
,
str
,
int
,
str
]],
fused
:
bool
=
True
,
)
->
bool
:
expert_param_loaded
=
False
if
"experts.gate_up_proj"
in
name
:
loaded_weight
=
loaded_weight
.
chunk
(
2
,
dim
=-
1
)
for
(
param_name
,
weight_name
,
expert_id
,
shard_id
)
in
expert_params_mapping
:
new_loaded_weight
=
loaded_weight
if
fused
:
e_str
,
_
,
proj_str
,
_
=
weight_name
.
split
(
'.'
)
weight_name
=
f
"
{
e_str
}
.
{
proj_str
}
"
param_name
=
f
"
{
param_name
}
weight"
if
weight_name
not
in
name
:
continue
full_param_name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
full_param_name
]
weight_loader
=
param
.
weight_loader
if
fused
:
if
"w13"
in
full_param_name
:
shard_idx
=
0
if
shard_id
==
"w1"
else
1
new_loaded_weight
=
new_loaded_weight
[
shard_idx
]
new_loaded_weight
=
new_loaded_weight
.
transpose
(
-
1
,
-
2
)
layer_idx
=
extract_layer_index
(
name
)
# EP mapping
expert_map
=
self
.
layers
[
layer_idx
].
feed_forward
.
experts
.
expert_map
if
expert_map
is
not
None
:
local_expert_indices
=
(
expert_map
!=
-
1
)
\
.
nonzero
()
\
.
flatten
()
\
.
to
(
new_loaded_weight
.
device
)
new_loaded_weight
=
new_loaded_weight
[
local_expert_indices
]
expert_id
=
local_expert_indices
[
0
].
item
()
else
:
# TODO: add EP support for non fused weights
pass
weight_loader
(
param
,
new_loaded_weight
,
full_param_name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
loaded_params
.
add
(
full_param_name
)
expert_param_loaded
=
True
return
expert_param_loaded
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
fused_experts_params
=
False
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
num_experts
)
expert_params_mapping_fused
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_up_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"gate_up_proj"
,
num_experts
=
1
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
fused_experts_params
=
True
expert_params_mapping
=
expert_params_mapping_fused
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
"experts"
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
break
else
:
moe_loaded
=
self
.
load_moe_expert_weights
(
name
,
loaded_weight
,
params_dict
,
loaded_params
,
expert_params_mapping
,
fused
=
fused_experts_params
)
if
not
moe_loaded
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Llama4ForCausalLM
(
LlamaForCausalLM
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# Update temperature tuning config from generation config
gen_config
=
vllm_config
.
model_config
.
try_get_generation_config
()
gen_config
.
update
(
vllm_config
.
model_config
.
override_generation_config
)
vllm_config
.
model_config
.
hf_config
.
attn_temperature_tuning
\
=
gen_config
.
get
(
"attn_temperature_tuning"
,
False
)
LlamaForCausalLM
.
__init__
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
Llama4DecoderLayer
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
return
Llama4Model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
layer_type
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
weights
=
[
self
.
permute_qk_weight_for_rotary
(
name
,
loaded_weight
)
for
name
,
loaded_weight
in
weights
]
return
loader
.
load_weights
(
weights
)
def
permute_qk_weight_for_rotary
(
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
,
)
->
Tuple
[
str
,
torch
.
Tensor
]:
def
permute
(
w
:
torch
.
Tensor
,
n_heads
:
int
):
attn_in
=
self
.
config
.
head_dim
*
n_heads
attn_out
=
self
.
config
.
hidden_size
return
w
.
view
(
n_heads
,
attn_in
//
n_heads
//
2
,
2
,
attn_out
).
transpose
(
1
,
2
).
reshape
(
attn_in
,
attn_out
)
modules
=
name
.
split
(
"."
)
# rotary embeds should be sliced
if
(
"wk"
in
modules
or
"k_proj"
in
modules
)
\
and
modules
[
-
1
]
==
"weight"
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_key_value_heads
)
elif
(
"wq"
in
modules
or
"q_proj"
in
modules
)
\
and
modules
[
-
1
]
==
"weight"
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_attention_heads
)
return
name
,
loaded_weight
vllm/model_executor/models/mllama4.py
0 → 100644
View file @
c5752323
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
collections.abc
import
Iterable
,
Mapping
from
itertools
import
tee
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
,
Llama4Config
,
Llama4VisionConfig
from
transformers.image_utils
import
SizeDict
from
transformers.modeling_outputs
import
BaseModelOutput
from
transformers.models.llama4
import
Llama4Processor
from
transformers.models.llama4.image_processing_llama4_fast
import
(
find_supported_resolutions
,
get_best_fit
)
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
logger
=
init_logger
(
__name__
)
class
Llama4ImagePatchInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
flat_data
:
torch
.
Tensor
"""
Shape:
`(batch_size * num_chunks, num_channels, image size, image size)`
"""
patches_per_image
:
torch
.
Tensor
"""
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A list of aspect ratios corresponding to the number of tiles
in each dimension that each image in the batch corresponds to.
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
"""
class
Llama4VisionMLP
(
nn
.
Module
):
def
__init__
(
self
,
input_size
:
int
,
intermediate_size
:
int
,
output_size
:
int
,
bias
:
bool
,
output_activation
:
bool
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
input_size
,
output_size
=
intermediate_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
fc2
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
output_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
)
self
.
activation_fn
=
nn
.
GELU
()
self
.
output_activation
=
output_activation
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
,
_
=
self
.
fc2
(
hidden_states
)
if
self
.
output_activation
:
return
self
.
activation_fn
(
hidden_states
)
return
hidden_states
class
Llama4MultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
linear_1
=
ColumnParallelLinear
(
input_size
=
config
.
vision_config
.
vision_output_dim
,
output_size
=
config
.
text_config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
gather_output
=
True
,
prefix
=
f
"
{
prefix
}
.linear_1"
,
)
def
forward
(
self
,
image_features
):
hidden_states
,
_
=
self
.
linear_1
(
image_features
)
return
hidden_states
def
pixel_shuffle
(
input_tensor
,
shuffle_ratio
):
# input_tensor: [batch_size, num_patches, channels]
batch_size
,
num_patches
,
channels
=
input_tensor
.
shape
patch_size
=
int
(
math
.
sqrt
(
num_patches
))
input_tensor
=
input_tensor
.
view
(
batch_size
,
patch_size
,
patch_size
,
-
1
)
batch_size
,
height
,
width
,
channels
=
input_tensor
.
size
()
reshaped_tensor
=
input_tensor
.
view
(
batch_size
,
height
,
int
(
width
*
shuffle_ratio
),
int
(
channels
/
shuffle_ratio
))
reshaped_tensor
=
reshaped_tensor
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
reshaped_tensor
=
reshaped_tensor
.
view
(
batch_size
,
int
(
height
*
shuffle_ratio
),
int
(
width
*
shuffle_ratio
),
int
(
channels
/
(
shuffle_ratio
**
2
)))
reshaped_tensor
=
reshaped_tensor
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
output_tensor
=
reshaped_tensor
.
view
(
batch_size
,
-
1
,
reshaped_tensor
.
shape
[
-
1
])
return
output_tensor
class
Llama4VisionPixelShuffleMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
pixel_shuffle_ratio
=
config
.
pixel_shuffle_ratio
self
.
inner_dim
=
int
(
config
.
projector_input_dim
//
(
self
.
pixel_shuffle_ratio
**
2
))
self
.
output_dim
=
config
.
projector_output_dim
self
.
mlp
=
Llama4VisionMLP
(
input_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
projector_input_dim
,
output_size
=
config
.
projector_output_dim
,
bias
=
config
.
multi_modal_projector_bias
,
output_activation
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
self
,
encoded_patches
:
torch
.
Tensor
)
->
torch
.
Tensor
:
encoded_patches
=
pixel_shuffle
(
encoded_patches
,
self
.
pixel_shuffle_ratio
)
return
self
.
mlp
(
encoded_patches
)
class
Llama4VisionAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_heads
assert
self
.
num_heads
%
self
.
tp_size
==
0
self
.
num_local_heads
=
self
.
num_heads
//
self
.
tp_size
self
.
q_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
attention_dropout
=
config
.
attention_dropout
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
MultiHeadAttention
(
self
.
num_local_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
embed_dim
,
bias
=
True
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
head_size
=
self
.
head_dim
,
rotary_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
//
2
,
# number of image patches
max_position
=
(
config
.
image_size
//
config
.
patch_size
)
**
2
,
base
=
config
.
rope_theta
,
rope_scaling
=
{
"rope_type"
:
"mllama4"
},
is_neox_style
=
False
,
dtype
=
torch
.
complex64
,
# important
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
input_shape
=
hidden_states
.
shape
[:
-
1
]
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
q
.
view
(
q
.
shape
[
0
],
q
.
shape
[
1
],
self
.
num_local_heads
,
self
.
head_dim
)
k
=
k
.
view
(
k
.
shape
[
0
],
k
.
shape
[
1
],
self
.
num_local_heads
,
self
.
head_dim
)
q
,
k
=
self
.
rotary_emb
(
q
,
k
)
q
=
q
.
view
(
q
.
shape
[
0
],
q
.
shape
[
1
],
-
1
)
k
=
k
.
view
(
k
.
shape
[
0
],
k
.
shape
[
1
],
-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
attn_output
.
reshape
(
*
input_shape
,
-
1
).
contiguous
()
attn_output
,
_
=
self
.
o_proj
(
attn_output
)
return
attn_output
class
Llama4VisionEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
intermediate_size
=
config
.
intermediate_size
self
.
self_attn
=
Llama4VisionAttention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
mlp
=
Llama4VisionMLP
(
input_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
output_size
=
config
.
hidden_size
,
bias
=
True
,
output_activation
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
):
# Self Attention
residual
=
hidden_state
hidden_state
=
self
.
input_layernorm
(
hidden_state
)
hidden_state
=
self
.
self_attn
(
hidden_state
)
hidden_state
=
residual
+
hidden_state
# Feed forward
residual
=
hidden_state
hidden_state
=
self
.
post_attention_layernorm
(
hidden_state
)
hidden_state
=
self
.
mlp
(
hidden_state
)
hidden_state
=
residual
+
hidden_state
outputs
=
(
hidden_state
,
)
return
outputs
class
Llama4VisionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
Llama4VisionEncoderLayer
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
layer_idx
}
"
,
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
BaseModelOutput
:
r
"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation. This is useful if you
want more control over how to convert `input_ids` indices into
associated vectors than the model's internal embedding
lookup matrix.
"""
for
encoder_layer
in
self
.
layers
:
layer_outputs
=
encoder_layer
(
hidden_states
)
hidden_states
=
layer_outputs
[
0
]
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
)
class
Llama4UnfoldConvolution
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
kernel_size
=
config
.
patch_size
if
isinstance
(
kernel_size
,
int
):
kernel_size
=
(
kernel_size
,
kernel_size
)
self
.
unfold
=
torch
.
nn
.
Unfold
(
kernel_size
=
kernel_size
,
stride
=
config
.
patch_size
)
self
.
linear
=
ColumnParallelLinear
(
config
.
num_channels
*
kernel_size
[
0
]
*
kernel_size
[
1
],
config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
gather_output
=
True
,
prefix
=
f
"
{
prefix
}
.linear"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
unfold
(
hidden_states
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
)
hidden_states
,
_
=
self
.
linear
(
hidden_states
)
return
hidden_states
class
Llama4VisionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Llama4VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
hidden_size
=
config
.
hidden_size
self
.
num_channels
=
config
.
num_channels
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
+
1
self
.
scale
=
config
.
hidden_size
**-
0.5
self
.
patch_embedding
=
Llama4UnfoldConvolution
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.patch_embedding"
)
self
.
class_embedding
=
nn
.
Parameter
(
self
.
scale
*
torch
.
randn
(
self
.
hidden_size
))
self
.
positional_embedding_vlm
=
nn
.
Parameter
(
self
.
scale
*
torch
.
randn
(
self
.
num_patches
,
self
.
hidden_size
))
# layer norms
self
.
layernorm_pre
=
nn
.
LayerNorm
(
self
.
hidden_size
,
eps
=
1e-5
)
self
.
layernorm_post
=
nn
.
LayerNorm
(
self
.
hidden_size
,
eps
=
1e-5
)
# encoders
self
.
model
=
Llama4VisionEncoder
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.model"
)
self
.
vision_adapter
=
Llama4VisionPixelShuffleMLP
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_adapter"
)
def
forward
(
self
,
images_flattened
:
torch
.
Tensor
,
)
->
BaseModelOutput
:
# Patch embedding
hidden_state
=
self
.
patch_embedding
(
images_flattened
)
num_tiles
,
num_patches
,
hidden_dim
=
hidden_state
.
shape
# Add cls token
class_embedding
=
self
.
class_embedding
.
expand
(
hidden_state
.
shape
[
0
],
1
,
hidden_state
.
shape
[
-
1
])
hidden_state
=
torch
.
cat
([
hidden_state
,
class_embedding
],
dim
=
1
)
num_patches
+=
1
# Position embeddings
hidden_state
=
hidden_state
.
reshape
(
num_tiles
,
1
,
num_patches
,
hidden_dim
,
)
positional_embedding
=
self
.
positional_embedding_vlm
.
to
(
dtype
=
hidden_state
.
dtype
,
device
=
hidden_state
.
device
)
hidden_state
=
hidden_state
+
positional_embedding
hidden_state
=
self
.
layernorm_pre
(
hidden_state
)
hidden_state
=
hidden_state
.
view
(
num_tiles
,
-
1
,
hidden_dim
)
# Apply encoder
output
=
self
.
model
(
hidden_state
)
hidden_state
=
output
.
last_hidden_state
hidden_state
=
self
.
layernorm_post
(
hidden_state
)
# Remove CLS token output
hidden_state
=
hidden_state
[:,
:
-
1
,
:]
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state
=
self
.
vision_adapter
(
hidden_state
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_state
,
attentions
=
None
,
)
class
Mllama4ProcessingInfo
(
BaseProcessingInfo
):
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
(
ctx
)
def
get_hf_config
(
self
)
->
Llama4Config
:
return
self
.
ctx
.
get_hf_config
(
Llama4Config
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
Llama4Processor
:
return
self
.
ctx
.
get_hf_processor
(
Llama4Processor
,
use_fast
=
True
,
**
kwargs
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
10
}
@
staticmethod
def
get_patch_per_chunk
(
vision_config
:
Llama4VisionConfig
)
->
int
:
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
assert
(
image_size
%
patch_size
==
0
),
f
"chunk size
{
image_size
}
should be multiple of "
f
"patch_size
{
patch_size
}
"
ds_ratio
=
int
(
round
(
1.0
/
(
vision_config
.
pixel_shuffle_ratio
**
2
)))
return
(
image_size
//
patch_size
)
**
2
//
ds_ratio
def
get_max_num_tiles
(
self
)
->
int
:
image_processor
=
self
.
get_hf_processor
().
image_processor
return
image_processor
.
max_patches
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
vision_config
=
self
.
get_hf_config
().
vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk
=
self
.
get_patch_per_chunk
(
vision_config
)
+
1
mm_max_tokens
=
(
self
.
get_max_num_tiles
()
+
1
)
*
token_per_chunk
+
2
return
{
"image"
:
mm_max_tokens
}
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
vision_config
=
self
.
get_hf_config
().
vision_config
image_size
=
vision_config
.
image_size
# Result in the max possible feature size (h:w = 16:1)
return
ImageSize
(
height
=
self
.
get_max_num_tiles
()
*
image_size
,
width
=
image_size
)
class
Mllama4MultiModalProcessor
(
BaseMultiModalProcessor
[
Mllama4ProcessingInfo
]
):
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
tokenizer
=
self
.
info
.
get_tokenizer
()
if
mm_data
is
None
:
return
tokenizer
(
prompt
,
add_special_tokens
=
False
)
# exclude bos
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
)
processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_processor
=
processor
.
image_processor
vision_config
=
self
.
info
.
get_hf_config
().
vision_config
if
processed_outputs
.
get
(
"pixel_values"
)
is
not
None
:
assert
"images"
in
mm_data
,
\
"images expected to be in mm_data when pixel_values is present"
images
=
mm_data
[
"images"
]
parsed_images
=
(
self
.
_get_data_parser
().
parse_mm_data
({
"image"
:
images
}).
get_items
(
"image"
,
ImageProcessorItems
))
tile_size
=
vision_config
.
image_size
possible_resolutions
=
find_supported_resolutions
(
max_num_chunks
=
self
.
info
.
get_max_num_tiles
(),
patch_size
=
SizeDict
(
height
=
tile_size
,
width
=
tile_size
),
)
best_fit_sizes
=
[
get_best_fit
(
(
image
.
size
[
1
],
image
.
size
[
0
]),
torch
.
tensor
(
possible_resolutions
),
resize_to_max_canvas
=
image_processor
.
resize_to_max_canvas
)
for
image
in
parsed_images
]
# TODO tile height/width do not necessarily need to match
aspect_ratios
=
[(
image_size
[
0
]
//
tile_size
,
image_size
[
1
]
//
tile_size
)
for
image_size
in
best_fit_sizes
]
patches_per_image
=
[
1
if
r_h
*
r_w
==
1
else
1
+
r_h
*
r_w
for
(
r_h
,
r_w
)
in
aspect_ratios
]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id
=
tokenizer
.
get_vocab
()[
processor
.
img_patch_token
]
num_patches_per_chunk
=
self
.
info
.
get_patch_per_chunk
(
vision_config
)
expanded_image_tokens_list
=
[
processor
.
_prompt_split_image
(
aspect_ratio
,
num_patches_per_chunk
)
for
aspect_ratio
in
aspect_ratios
]
expanded_image_token_ids
=
[
tokenizer
.
encode
(
image_tokens
,
add_special_tokens
=
False
)
for
image_tokens
in
expanded_image_tokens_list
]
embed_is_patch
=
[
torch
.
tensor
(
tokens
)
==
patch_id
for
tokens
in
expanded_image_token_ids
]
processed_outputs
[
"aspect_ratios"
]
=
aspect_ratios
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
patches_per_image
)
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
patches_per_image
=
hf_inputs
.
get
(
"patches_per_image"
,
torch
.
empty
(
0
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
patches_per_image
),
patches_per_image
=
MultiModalFieldConfig
.
batched
(
"image"
),
aspect_ratios
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
List
[
PromptUpdate
]:
assert
(
mm_items
.
get_count
(
"image"
,
strict
=
False
)
==
0
or
"aspect_ratios"
in
out_mm_kwargs
),
"Transformers expect to include aspect_ratios in out_mm_kwargs"
config
=
self
.
info
.
get_hf_config
()
vision_config
=
config
.
vision_config
num_patches_per_chunk
=
self
.
info
.
get_patch_per_chunk
(
vision_config
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
def
get_replacement
(
item_idx
:
int
):
aspect_ratio
=
out_mm_kwargs
[
"aspect_ratios"
][
item_idx
]
return
hf_processor
.
_prompt_split_image
(
aspect_ratio
=
aspect_ratio
,
num_patches_per_chunk
=
num_patches_per_chunk
)
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
image_token
,
replacement
=
get_replacement
,
)
]
class
Mllama4DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Mllama4ProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
(
target_width
,
target_height
)
=
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
image_token
=
self
.
info
.
get_hf_processor
().
fake_image_token
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
Mllama4MultiModalProcessor
,
info
=
Mllama4ProcessingInfo
,
dummy_inputs
=
Mllama4DummyInputsBuilder
,
)
class
Llama4ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
multimodal_config
=
multimodal_config
self
.
vision_model
=
Llama4VisionModel
(
config
.
vision_config
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
))
self
.
multi_modal_projector
=
Llama4MultiModalProjector
(
self
.
config
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
architectures
=
[
"Llama4ForCausalLM"
],
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
self
.
tokenizer
=
cached_tokenizer_from_config
(
vllm_config
.
model_config
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Llama4ImagePatchInputs
]:
# num_images, 1, num_chunks, channel, image_size, image_size
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
pixel_values
is
None
:
return
None
# num_images x num_chunks, channel, image_size, image_size
# TODO: confirm handling for variable lengths
flat_pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
patches_per_image
=
flatten_bn
(
kwargs
.
pop
(
"patches_per_image"
))
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
aspect_ratios
=
kwargs
.
pop
(
"aspect_ratios"
,
None
)
if
not
isinstance
(
aspect_ratios
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of aspect_ratios. "
f
"Got type:
{
type
(
aspect_ratios
)
}
"
)
return
Llama4ImagePatchInputs
(
type
=
"pixel_values"
,
flat_data
=
flat_pixel_values
,
patches_per_image
=
patches_per_image
,
embed_is_patch
=
embed_is_patch
,
aspect_ratios
=
aspect_ratios
,
)
def
_process_image_input
(
self
,
image_input
:
Llama4ImagePatchInputs
)
->
MultiModalEmbeddings
:
flat_data
=
image_input
[
"flat_data"
]
patches_per_image
=
image_input
[
"patches_per_image"
].
tolist
()
vision_embeddings_flat
=
self
.
vision_model
(
flat_data
).
last_hidden_state
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
# num_images x [num_chunks, num_patches, hidden_dim]
image_features
=
self
.
_process_image_input
(
image_input
)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat
=
[
img
.
flatten
(
0
,
1
)
for
img
in
image_features
]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat
=
[
is_patch
.
flatten
(
0
,
1
)
for
is_patch
in
image_input
[
"embed_is_patch"
]
]
return
scatter_patch_features
(
image_features_flat
,
embed_is_patch_flat
,
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
multimodal_embeddings
=
torch
.
cat
(
multimodal_embeddings
)
mm_embeddings
=
self
.
multi_modal_projector
(
multimodal_embeddings
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
select_patch_features
(
mm_embeddings
),
self
.
config
.
image_token_index
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if
"pixel_values"
in
kwargs
:
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
input_ids
=
None
return
self
.
language_model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
separate_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
,
)
->
Tuple
[
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]]:
weights1
,
weights2
=
tee
(
weights
,
2
)
def
get_prefix_weights
()
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
for
name
,
data
in
weights1
:
if
name
.
startswith
(
prefix
):
yield
(
name
,
data
)
def
get_other_weights
()
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
for
name
,
data
in
weights2
:
if
not
name
.
startswith
(
prefix
):
yield
(
name
,
data
)
return
get_prefix_weights
(),
get_other_weights
()
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
:
Set
[
str
]
=
set
()
# language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine.
language_model_prefix
=
"language_model.model."
language_model_weights
,
other_weights
=
self
.
separate_weights
(
weights
,
prefix
=
language_model_prefix
)
loader
=
AutoWeightsLoader
(
self
)
loaded_language_model_params
=
loader
.
load_weights
(
language_model_weights
)
assert
loaded_language_model_params
is
not
None
updated_params
.
update
(
loaded_language_model_params
)
for
name
,
loaded_weight
in
other_weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
updated_params
.
add
(
name
)
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
updated_params
.
add
(
name
)
return
updated_params
vllm/model_executor/models/registry.py
View file @
c5752323
...
@@ -73,6 +73,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -73,6 +73,7 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Llama4ForCausalLM"
:
(
"llama4"
,
"Llama4ForCausalLM"
),
# For decapoda-research/llama-*
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
"MambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
...
@@ -194,6 +195,7 @@ _MULTIMODAL_MODELS = {
...
@@ -194,6 +195,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
# [Encoder-decoder]
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"Llama4ForConditionalGeneration"
:
(
"mllama4"
,
"Llama4ForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
}
}
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
c5752323
...
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
...
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if are performing a chunked prefill a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
tensor
,
page_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block
=
np
.
minimum
(
attn_chunk_size
-
((
seq_lens_np
-
q_seqlens
)
%
attn_chunk_size
),
q_seqlens
).
astype
(
np
.
int32
)
tokens_in_last_block
=
attn_chunk_size
+
(
seq_lens_np
%
-
attn_chunk_size
)
local_blocks
=
1
+
cdiv
(
q_seqlens
-
q_tokens_in_first_block
,
attn_chunk_size
)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks
=
np
.
cumsum
(
local_blocks
)
virtual_batches
=
cu_num_blocks
[
-
1
]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
local_blocks
,
local_blocks
)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange
=
np
.
repeat
(
local_blocks
,
local_blocks
)
-
arange
-
1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local
=
\
np
.
repeat
(
q_seqlens
-
q_tokens_in_first_block
,
local_blocks
)
# set the first block since this may be a partial block
seqlens_q_local
[
arange
==
0
]
=
q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local
[
arange
>
0
]
=
np
.
minimum
(
seqlens_q_local
-
attn_chunk_size
*
(
arange
-
1
),
attn_chunk_size
)[
arange
>
0
]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local
=
np
.
pad
(
np
.
cumsum
(
seqlens_q_local
),
(
1
,
0
))
\
.
astype
(
np
.
int32
)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local
=
np
.
full
(
cu_num_blocks
[
-
1
],
attn_chunk_size
,
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
np
.
repeat
(
tokens_in_last_block
,
local_blocks
))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
page_size
assert
attn_chunk_size
%
page_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by page_size
{
page_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
page_size
# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices
=
np
.
broadcast_to
(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
()
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
block_table_local
class
FlashAttentionMetadataBuilder
:
class
FlashAttentionMetadataBuilder
:
...
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
...
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
common_prefix_len
:
int
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
query_start_loc_cpu
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
]
self
.
runner
.
device
,
non_blocking
=
True
)
query_start_loc
=
query_start_loc_cpu
.
to
(
self
.
runner
.
device
,
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
block_table
=
(
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
# for local attention
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table
,
self
.
runner
.
block_size
,
)
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_block_table
=
virt_block_table
,
local_max_query_len
=
seqlens_q_local_np
.
max
(),
local_max_seq_len
=
virt_k_seqlens_np
.
max
(),
)
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
...
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
...
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
)
)
return
attn_metadata
return
attn_metadata
...
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
"FlashAttentionImpl"
)
"FlashAttentionImpl"
)
self
.
use_irope
=
use_irope
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
and
not
flash_attn_supports_fp8
():
and
not
flash_attn_supports_fp8
():
...
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
descale_shape
=
(
attn_metadata
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
...
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
use_local_attn
=
\
# Regular attention (common case).
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
attn_metadata
.
max_query_
len
,
max_seqlen_q
=
max_seq
len
_q
,
seqused_k
=
attn_metadata
.
seq_lens
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
attn_metadata
.
max_seq
_
len
,
max_seqlen_k
=
max_seqlen
_k
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
...
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
)
return
output
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
# Cascade attention (rare case).
cascade_attention
(
cascade_attention
(
output
[:
num_actual_tokens
],
output
[:
num_actual_tokens
],
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
c5752323
...
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else
:
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
use_irope
=
use_irope
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
sequesd_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
sequesd_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
# Compute attention and update output up to `num_actual_tokens`.
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
value_cache
=
value_cache
,
block_table
=
block_table
,
block_table
=
attn_metadata
.
block_table
,
query_start_loc
=
cu_seqlens_q
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
sequesd_k
,
seq_lens
=
attn_metadata
.
seq_lens
,
max_seq_len
=
max_seqlen_k
,
max_seq_len
=
attn_metadata
.
max_seq_len
,
max_query_len
=
max_seqlen_q
,
max_query_len
=
attn_metadata
.
max_query_len
,
k_scale
=
layer
.
_k_scale
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
[
0
],
sliding_window
=
self
.
sliding_window
[
0
],
sm_scale
=
self
.
scale
)
sm_scale
=
self
.
scale
)
return
output
return
output
vllm/v1/worker/gpu_model_runner.py
View file @
c5752323
...
@@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
head_size
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment