Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c5752323
Unverified
Commit
c5752323
authored
Apr 05, 2025
by
Lu Fang
Committed by
GitHub
Apr 05, 2025
Browse files
[Model] Support Llama4 in vLLM (#16104)
parent
63375f0c
Changes
35
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1864 additions
and
91 deletions
+1864
-91
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+24
-14
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+15
-12
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+6
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+5
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+18
-15
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+17
-13
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+68
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+17
-5
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+530
-0
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+886
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+236
-14
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+37
-18
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
c5752323
...
...
@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -240,13 +241,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
...
...
@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
out_dtype
=
x
.
dtype
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
...
...
@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
...
...
@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for "
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
c5752323
...
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -129,7 +130,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
...
...
@@ -138,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
activation
=
activation
,
use_int8_w8a16
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
c5752323
...
...
@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation
=
activation
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
c5752323
...
...
@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
):
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused GGUF MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
c5752323
...
...
@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
is
not
None
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused Marlin MoE method."
)
# The input must currently be float16
orig_dtype
=
x
.
dtype
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
c5752323
...
...
@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -312,7 +313,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
=
topk_weights
,
...
...
@@ -321,6 +323,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
c5752323
...
...
@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -217,7 +219,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
...
...
@@ -225,6 +228,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
inplace
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
c5752323
...
...
@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return
new_freqs
class
Llama4VisionRotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
inv_freqs
=
inv_freqs
[:(
self
.
rotary_dim
//
2
)]
return
inv_freqs
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
# self.max_position_embeddings here is number of image patches
# i.e. (image_size // patch_size) ** 2
num_patches
=
self
.
max_position_embeddings
img_idx
=
torch
.
arange
(
num_patches
,
dtype
=
torch
.
int32
)
\
.
reshape
(
num_patches
,
1
)
img_idx
=
torch
.
cat
([
img_idx
,
img_idx
[:
1
]],
dim
=
0
)
img_idx
[
-
1
,
-
1
]
=
-
2
# set to ID_CLS_TOKEN
num_patches_single_dim
=
int
(
math
.
sqrt
(
num_patches
))
frequencies_x
=
img_idx
%
num_patches_single_dim
frequencies_y
=
img_idx
//
num_patches_single_dim
freqs_x
=
((
frequencies_x
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs_y
=
((
frequencies_y
+
1
)[...,
None
]
*
inv_freq
[
None
,
None
,
:]).
repeat_interleave
(
2
,
dim
=-
1
)
freqs
=
torch
.
cat
([
freqs_x
,
freqs_y
],
dim
=-
1
).
float
().
contiguous
()[...,
::
2
]
freqs
=
freqs
.
masked_fill
(
img_idx
.
reshape
(
-
1
,
1
,
1
)
<
0
,
0
)
cache
=
torch
.
view_as_complex
(
torch
.
stack
([
torch
.
cos
(
freqs
),
torch
.
sin
(
freqs
)],
dim
=-
1
))
return
cache
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
key_
=
torch
.
view_as_complex
(
key
.
float
().
reshape
(
*
key
.
shape
[:
-
1
],
-
1
,
2
))
broadcast_shape
=
[
d
if
i
==
1
or
i
==
(
query_
.
ndim
-
1
)
else
1
for
i
,
d
in
enumerate
(
query_
.
shape
)
]
freqs_ci
=
self
.
cos_sin_cache
.
view
(
*
broadcast_shape
)
query_out
=
torch
.
view_as_real
(
query_
*
freqs_ci
).
flatten
(
3
)
key_out
=
torch
.
view_as_real
(
key_
*
freqs_ci
).
flatten
(
3
)
return
query_out
.
type_as
(
query
),
key_out
.
type_as
(
key
)
class
MRotaryEmbedding
(
RotaryEmbedding
):
"""Rotary Embedding with Multimodal Sections."""
...
...
@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position
)
elif
scaling_type
==
"mllama4"
:
rotary_emb
=
Llama4VisionRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"default"
:
if
"mrope_section"
in
rope_scaling
:
rotary_emb
=
MRotaryEmbedding
(
...
...
vllm/model_executor/models/llama.py
View file @
c5752323
...
...
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
reduce_results
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
...
...
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
if
hidden_act
!=
"silu"
:
...
...
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
"norm"
:
"model.norm"
,
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
...
...
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
),
layer_type
=
layer_type
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
Type
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
layer_type
=
layer_type
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/llama4.py
0 → 100644
View file @
c5752323
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/mllama4.py
0 → 100644
View file @
c5752323
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
c5752323
...
...
@@ -73,6 +73,7 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Llama4ForCausalLM"
:
(
"llama4"
,
"Llama4ForCausalLM"
),
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MambaForCausalLM"
:
(
"mamba"
,
"MambaForCausalLM"
),
...
...
@@ -194,6 +195,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"Llama4ForConditionalGeneration"
:
(
"mllama4"
,
"Llama4ForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
}
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
c5752323
...
...
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if are performing a chunked prefill a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
tensor
,
page_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block
=
np
.
minimum
(
attn_chunk_size
-
((
seq_lens_np
-
q_seqlens
)
%
attn_chunk_size
),
q_seqlens
).
astype
(
np
.
int32
)
tokens_in_last_block
=
attn_chunk_size
+
(
seq_lens_np
%
-
attn_chunk_size
)
local_blocks
=
1
+
cdiv
(
q_seqlens
-
q_tokens_in_first_block
,
attn_chunk_size
)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks
=
np
.
cumsum
(
local_blocks
)
virtual_batches
=
cu_num_blocks
[
-
1
]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
local_blocks
,
local_blocks
)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange
=
np
.
repeat
(
local_blocks
,
local_blocks
)
-
arange
-
1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local
=
\
np
.
repeat
(
q_seqlens
-
q_tokens_in_first_block
,
local_blocks
)
# set the first block since this may be a partial block
seqlens_q_local
[
arange
==
0
]
=
q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local
[
arange
>
0
]
=
np
.
minimum
(
seqlens_q_local
-
attn_chunk_size
*
(
arange
-
1
),
attn_chunk_size
)[
arange
>
0
]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local
=
np
.
pad
(
np
.
cumsum
(
seqlens_q_local
),
(
1
,
0
))
\
.
astype
(
np
.
int32
)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local
=
np
.
full
(
cu_num_blocks
[
-
1
],
attn_chunk_size
,
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
np
.
repeat
(
tokens_in_last_block
,
local_blocks
))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
page_size
assert
attn_chunk_size
%
page_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by page_size
{
page_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
page_size
# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices
=
np
.
broadcast_to
(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
()
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
block_table_local
class
FlashAttentionMetadataBuilder
:
...
...
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
query_start_loc_cpu
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
]
query_start_loc
=
query_start_loc_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
# for local attention
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table
,
self
.
runner
.
block_size
,
)
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_block_table
=
virt_block_table
,
local_max_query_len
=
seqlens_q_local_np
.
max
(),
local_max_seq_len
=
virt_k_seqlens_np
.
max
(),
)
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
...
...
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
)
return
attn_metadata
...
...
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
self
.
use_irope
=
use_irope
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
and
not
flash_attn_supports_fp8
():
...
...
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
descale_shape
=
(
attn_metadata
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
...
...
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
# Regular attention (common case).
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_
len
,
seqused_k
=
attn_metadata
.
seq_lens
,
max_seqlen_k
=
attn_metadata
.
max_seq
_
len
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seq
len
_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen
_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
...
...
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
cascade_attention
(
output
[:
num_actual_tokens
],
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
c5752323
...
...
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
use_irope
=
use_irope
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -156,20 +158,37 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
)
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
sequesd_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
sequesd_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
block_table
=
attn_metadata
.
block_table
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
max_seq_len
=
attn_metadata
.
max_seq
_
len
,
max_query_len
=
attn_metadata
.
max_query_
len
,
block_table
=
block_table
,
query_start_loc
=
cu_seqlens_q
,
seq_lens
=
sequesd_k
,
max_seq_len
=
max_seqlen
_k
,
max_query_len
=
max_seq
len
_q
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c5752323
...
...
@@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment