Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3efb9f4d
Unverified
Commit
3efb9f4d
authored
Sep 04, 2025
by
whx
Committed by
GitHub
Sep 04, 2025
Browse files
[Attention][Platform] Refactor MLA to support Custom Op (#23332)
Signed-off-by:
whx-sjtu
<
2952154980@qq.com
>
parent
04f3c35c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
186 additions
and
58 deletions
+186
-58
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+158
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+28
-58
No files found.
vllm/model_executor/layers/mla.py
0 → 100644
View file @
3efb9f4d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
@dataclass
class MLAModules:
    """Modules used in MLA.

    Plain container for the sub-modules that MultiHeadLatentAttention copies
    onto itself in its constructor. Which of the optional entries must be
    set depends on whether the attention layer is built with a
    ``q_lora_rank``: the layer asserts on the ones it needs at forward time.
    """
    # Applied to the first ``kv_lora_rank`` channels of the KV projection
    # output before they are handed to the attention op.
    kv_a_layernorm: torch.nn.Module
    # Not called directly by the MLA layer; forwarded to ``Attention`` as its
    # ``kv_b_proj`` argument.
    kv_b_proj: torch.nn.Module
    # Called as ``rotary_emb(positions, q_rope_slice, k_pe)`` and expected to
    # return the rotated (q_rope_slice, k_pe) pair.
    rotary_emb: torch.nn.Module
    # Applied to the attention output; element ``[0]`` of its result is what
    # the MLA layer returns.
    o_proj: torch.nn.Module
    # Required when ``q_lora_rank`` is not None: its output is split into a
    # ``q_lora_rank`` part and a ``kv_lora_rank + qk_rope_head_dim`` part.
    fused_qkv_a_proj: Optional[torch.nn.Module]
    # Required when ``q_lora_rank`` is None: produces the KV latent plus the
    # rope channels from the hidden states.
    kv_a_proj_with_mqa: Optional[torch.nn.Module]
    # Required when ``q_lora_rank`` is not None: normalises the query latent.
    q_a_layernorm: Optional[torch.nn.Module]
    # Required when ``q_lora_rank`` is not None: projects the normalised
    # query latent to the per-head queries.
    q_b_proj: Optional[torch.nn.Module]
    # Required when ``q_lora_rank`` is None: projects the hidden states
    # directly to the per-head queries.
    q_proj: Optional[torch.nn.Module]
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttention(CustomOp):
    """MLA layer registered as CustomOp.

    Note that currently MLA ignores the enable/disable mechanism of CustomOp
    because there is only one in-tree implementation in forward_native.
    TODO: implement this with a new PluggableLayer mechanism.

    This class takes positions and hidden_states as input.
    The input tensors can either contain prefill tokens or decode tokens.
    The class does the following:

    1. MLA Preprocess.
    2. Perform multi-head attention to prefill tokens and
       multi-query attention to decode tokens separately.
    3. Return the output tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store the MLA dimensions and sub-modules and build the attention op.

        Args:
            hidden_size: Stored on the layer; not otherwise used here.
            num_heads: Number of query heads; used to reshape the query and
                passed to ``Attention``.
            scale: Softmax scale passed to ``Attention``.
            qk_nope_head_dim: Per-head query/key width without rotary
                embedding.
            qk_rope_head_dim: Per-head query/key width that receives rotary
                embedding.
            v_head_dim: Per-head value width; determines the output shape.
            q_lora_rank: Width of the query latent, or None to project the
                queries directly from the hidden states.
            kv_lora_rank: Width of the KV latent.
            mla_modules: Sub-modules to use; see ``forward_native`` for which
                optional ones are required.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to ``Attention``.
            prefix: Dotted module name; ``Attention`` is created under
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads

        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scale,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=self.kv_b_proj,
        )

        self.prefix = prefix
        # Debug aid only: None when the prefix carries no layer index (this
        # includes the default empty prefix), so that a missing index never
        # prevents the layer from being constructed.
        self.debug_layer_idx: Optional[int] = self._parse_layer_idx(prefix)

    @staticmethod
    def _parse_layer_idx(prefix: str) -> Optional[int]:
        """Return the integer in the second-to-last dotted component.

        For example ``"model.layers.3.self_attn"`` yields ``3``. Returns
        None when the prefix has fewer than two components or when that
        component is not an integer.
        """
        try:
            return int(prefix.split(".")[-2])
        except (IndexError, ValueError):
            return None

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the MLA pre-processing, the attention op and the output proj.

        Args:
            positions: Passed unchanged to ``rotary_emb``.
            hidden_states: Input activations; ``shape[0]`` is used as the
                leading dimension of the attention output.

        Returns:
            Element ``[0]`` of ``o_proj`` applied to the attention output.

        Raises:
            AssertionError: If a sub-module needed for the configured
                ``q_lora_rank`` mode is None.
        """
        if self.q_lora_rank is not None:
            # Low-rank query path: one fused projection yields both the
            # query latent and the KV latent (+ rope channels).
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            # The projection modules are indexed with [0]; only their first
            # returned element is used.
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            # Direct query path: separate projections for KV latent and q.
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_c, k_pe = kv_lora.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        # Only the trailing rope channels of q are rotated; the result is
        # written back into that slice of q.
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0],
                          self.num_heads * self.v_head_dim))
        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
        """Delegate to ``forward_native``; there is no separate CUDA path."""
        return self.forward_native(*args, **kwargs)
vllm/model_executor/models/deepseek_v2.py
View file @
3efb9f4d
...
@@ -47,6 +47,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -47,6 +47,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mla
import
MLAModules
,
MultiHeadLatentAttention
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
...
@@ -492,72 +493,41 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -492,72 +493,41 @@ class DeepseekV2MLAAttention(nn.Module):
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
# In the MLA backend, kv_cache includes both k_c and
mla_modules
=
MLAModules
(
# pe (i.e. decoupled position embeddings). In particular,
kv_a_layernorm
=
self
.
kv_a_layernorm
,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self
.
mla_attn
=
Attention
(
num_heads
=
self
.
num_local_heads
,
head_size
=
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
scale
=
self
.
scaling
,
num_kv_heads
=
1
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_mla
=
True
,
# MLA Args
q_lora_rank
=
self
.
q_lora_rank
,
kv_lora_rank
=
self
.
kv_lora_rank
,
qk_nope_head_dim
=
self
.
qk_nope_head_dim
,
qk_rope_head_dim
=
self
.
qk_rope_head_dim
,
qk_head_dim
=
self
.
qk_head_dim
,
v_head_dim
=
self
.
v_head_dim
,
kv_b_proj
=
self
.
kv_b_proj
,
kv_b_proj
=
self
.
kv_b_proj
,
rotary_emb
=
self
.
rotary_emb
,
o_proj
=
self
.
o_proj
,
fused_qkv_a_proj
=
self
.
fused_qkv_a_proj
if
self
.
q_lora_rank
is
not
None
else
None
,
kv_a_proj_with_mqa
=
self
.
kv_a_proj_with_mqa
if
self
.
q_lora_rank
is
None
else
None
,
q_a_layernorm
=
self
.
q_a_layernorm
if
self
.
q_lora_rank
is
not
None
else
None
,
q_b_proj
=
self
.
q_b_proj
if
self
.
q_lora_rank
is
not
None
else
None
,
q_proj
=
self
.
q_proj
if
self
.
q_lora_rank
is
None
else
None
,
)
self
.
mla_attn
=
MultiHeadLatentAttention
(
self
.
hidden_size
,
self
.
num_local_heads
,
self
.
scaling
,
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
,
self
.
v_head_dim
,
self
.
q_lora_rank
,
self
.
kv_lora_rank
,
mla_modules
,
cache_config
,
quant_config
,
prefix
,
)
)
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
q_c
=
None
return
self
.
mla_attn
(
positions
,
hidden_states
)
kv_lora
=
None
if
self
.
q_lora_rank
is
not
None
:
qkv_lora
=
self
.
fused_qkv_a_proj
(
hidden_states
)[
0
]
q_c
,
kv_lora
=
qkv_lora
.
split
(
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
dim
=-
1
,
)
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
kv_lora
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
]
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
kv_lora
.
split
([
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment