Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98fd089f
"wrappers/python/src/vscode:/vscode.git/clone" did not exist on "d58cc041374e0a94dabcc127bc46661e3a0cc817"
Unverified
Commit
98fd089f
authored
Feb 05, 2025
by
Isotr0py
Committed by
GitHub
Feb 04, 2025
Browse files
[VLM] Add MLA with pure RoPE support for deepseek-vl2 models (#12729)
parent
249824c3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
6 deletions
+30
-6
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+26
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+2
-1
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+2
-1
No files found.
vllm/attention/backends/mla/utils.py
View file @
98fd089f
...
@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic
,
current_platform_fp8_dtype
,
is_fp8
)
apply_fp8_linear_generic
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_dequantize
,
scaled_quantize
)
scaled_dequantize
,
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -174,6 +175,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -174,6 +175,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
v_head_dim
=
v_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
rotary_emb
=
rotary_emb
self
.
rotary_emb
=
rotary_emb
self
.
use_yarn_rope
=
isinstance
(
rotary_emb
,
DeepseekScalingRotaryEmbedding
)
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
...
@@ -420,6 +423,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -420,6 +423,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
apply_pure_rope
(
self
,
input_positions
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
seq_len
=
input_positions
.
size
(
0
)
ori_q_pe_shape
,
ori_k_pe_shape
=
q_pe
.
shape
,
k_pe
.
shape
q_pe
,
k_pe
=
self
.
rotary_emb
(
input_positions
,
q_pe
.
reshape
(
seq_len
,
-
1
),
k_pe
.
reshape
(
seq_len
,
-
1
),
)
q_pe
,
k_pe
=
q_pe
.
view
(
ori_q_pe_shape
),
k_pe
.
view
(
ori_k_pe_shape
)
return
q_pe
,
k_pe
def
forward
(
def
forward
(
self
,
self
,
layer
:
AttentionLayer
,
layer
:
AttentionLayer
,
...
@@ -444,13 +465,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -444,13 +465,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding)
# Restore head dim (for rotary embedding)
k_pe
=
k_pe
.
unsqueeze
(
1
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
rope_fn
=
(
self
.
rotary_emb
if
self
.
use_yarn_rope
else
self
.
apply_pure_rope
)
if
is_decode
:
if
is_decode
:
q_nope
=
self
.
_q_proj_and_k_up_proj
(
hidden_states_or_q_c
)
q_nope
=
self
.
_q_proj_and_k_up_proj
(
hidden_states_or_q_c
)
q_pe
=
torch
.
matmul
(
hidden_states_or_q_c
,
self
.
W_QR
)
\
q_pe
=
torch
.
matmul
(
hidden_states_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
q_pe
,
k_pe
=
\
q_pe
,
k_pe
=
rope_fn
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
else
:
else
:
assert
is_prefill
assert
is_prefill
q
=
self
.
q_proj
(
hidden_states_or_q_c
)[
0
]
\
q
=
self
.
q_proj
(
hidden_states_or_q_c
)[
0
]
\
...
@@ -458,7 +480,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -458,7 +480,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# TODO(lucas): there must be a nicer way to write this line
# TODO(lucas): there must be a nicer way to write this line
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
\
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
\
self
.
rotary_emb
(
rope_fn
(
attn_metadata
.
input_positions
,
attn_metadata
.
input_positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
98fd089f
...
@@ -414,7 +414,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -414,7 +414,8 @@ class DeepseekV2MLAAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
prefix
=
f
"
{
prefix
}
.o_proj"
)
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
98fd089f
...
@@ -422,7 +422,8 @@ class DeepseekV3MLAAttention(nn.Module):
...
@@ -422,7 +422,8 @@ class DeepseekV3MLAAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
prefix
=
f
"
{
prefix
}
.o_proj"
)
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment