Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14688ccd
Commit
14688ccd
authored
Dec 02, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
parents
55310f4f
fd559b9f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
55 deletions
+7
-55
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+7
-55
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
14688ccd
...
...
@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
,
_yarn_find_correction_range
,
_yarn_linear_ramp_mask
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -608,52 +608,6 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
self
.
max_position_embeddings
=
rope_scaling
[
"original_max_position_embeddings"
]
self
.
base
=
rope_theta
self
.
rotary_dim
=
qk_rope_head_dim
self
.
scaling_factor
=
scaling_factor
self
.
mscale
=
mscale
self
.
extrapolation_factor
=
1
self
.
beta_fast
=
32
self
.
beta_slow
=
1
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
"cuda"
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
self
.
rotary_dim
)
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_interpolation
=
1.0
/
(
scaling_factor
*
pos_freqs
)
low
,
high
=
_yarn_find_correction_range
(
self
.
beta_fast
,
self
.
beta_slow
,
self
.
rotary_dim
,
self
.
base
,
self
.
max_position_embeddings
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask
=
(
1
-
_yarn_linear_ramp_mask
(
low
,
high
,
self
.
rotary_dim
//
2
,
dtype
=
torch
.
float
))
*
self
.
extrapolation_factor
inv_freq
=
inv_freq_interpolation
*
(
1
-
inv_freq_mask
)
+
inv_freq_extrapolation
*
inv_freq_mask
return
inv_freq
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
scaling_factor
)
t
=
torch
.
arange
(
self
.
max_position_embeddings
*
self
.
scaling_factor
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
(
freqs
.
cos
()
*
self
.
mscale
)
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
def
forward
(
self
,
...
...
@@ -767,12 +721,10 @@ class DeepseekV2MLAAttention(nn.Module):
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
weight
=
torch
.
ones
(
kv_c
.
shape
[
-
1
],
dtype
=
q
.
dtype
,
device
=
kv_c
.
device
)
weight
=
nn
.
Parameter
(
weight
)
if
self
.
cos_sin_cache
.
device
!=
positions
.
device
:
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
if
self
.
cos_sin_cache
.
device
!=
q
.
dtype
:
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
q
.
dtype
)
weight
=
self
.
kv_a_layernorm
.
weight
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
cos_sin_cache
.
device
!=
positions
.
device
or
cos_sin_cache
.
device
!=
q
.
dtype
:
cos_sin_cache
=
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
q
.
dtype
)
kv_c_normed
=
torch
.
empty
(
kv_c
.
shape
,
dtype
=
kv_c
.
dtype
,
device
=
kv_c
.
device
)
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
...
...
@@ -783,8 +735,8 @@ class DeepseekV2MLAAttention(nn.Module):
q_ori
=
q
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
weight
=
weight
.
data
,
cos_sin_cache
=
self
.
cos_sin_cache
)
weight
=
weight
,
cos_sin_cache
=
cos_sin_cache
)
return
self
.
o_proj
(
attn_out
)[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment