Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64e307c7
Commit
64e307c7
authored
Dec 01, 2025
by
zhuwenwen
Browse files
add VLLM_USE_OPT_RESHAPE_AND_CACHE (test)
parent
4c92e64a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
211 additions
and
48 deletions
+211
-48
vllm/attention/layer.py
vllm/attention/layer.py
+64
-19
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+87
-17
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+49
-10
No files found.
vllm/attention/layer.py
View file @
64e307c7
...
@@ -199,6 +199,11 @@ class Attention(nn.Module):
...
@@ -199,6 +199,11 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
# definition specify the output tensor shape.
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
The KV cache is stored inside this class and is accessed via
The KV cache is stored inside this class and is accessed via
...
@@ -255,8 +260,12 @@ class Attention(nn.Module):
...
@@ -255,8 +260,12 @@ class Attention(nn.Module):
attn_metadata
,
attn_metadata
,
output
=
output
)
output
=
output
)
else
:
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
query
,
key
,
value
,
output
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
)
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
None
,
q_ori
,
key_normed
,
positions
,
weight
,
cos_sin_cache
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
...
@@ -497,6 +506,11 @@ def unified_attention_with_output(
...
@@ -497,6 +506,11 @@ def unified_attention_with_output(
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
wait_for_kv_layer_from_connector
(
layer_name
)
wait_for_kv_layer_from_connector
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
...
@@ -505,29 +519,60 @@ def unified_attention_with_output(
...
@@ -505,29 +519,60 @@ def unified_attention_with_output(
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
query
,
self
.
impl
.
forward
(
self
,
key
,
query
,
value
,
key
,
kv_cache
,
value
,
attn_metadata
,
kv_cache
,
output
=
output
,
attn_metadata
,
output_scale
=
output_scale
)
output
=
output
,
output_scale
=
output_scale
)
else
:
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
output
=
output
,
output_scale
=
output_scale
,
q_ori
=
q_ori
,
key_normed
=
key_normed
,
positions
=
positions
,
weight
=
weight
,
cos_sin_cache
=
cos_sin_cache
)
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
tbo_maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
else
:
else
:
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
def
unified_attention_with_output_fake
(
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
query
:
torch
.
Tensor
,
def
unified_attention_with_output_fake
(
key
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
output
:
torch
.
Tensor
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
layer_name
:
str
,
)
->
None
:
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
return
)
->
None
:
return
else
:
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    q_ori: Optional[torch.Tensor] = None,
    key_normed: Optional[torch.Tensor] = None,
    positions: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    cos_sin_cache: Optional[torch.Tensor] = None,
) -> None:
    """No-op variant of ``unified_attention_with_output``.

    This definition carries the extended signature (``q_ori``,
    ``key_normed``, ``positions``, ``weight``, ``cos_sin_cache``) that is
    selected when ``envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT`` is enabled.
    It performs no computation and leaves every argument untouched.

    Returns:
        None.
    """
    return None
direct_register_custom_op
(
direct_register_custom_op
(
...
...
vllm/envs.py
View file @
64e307c7
...
@@ -189,6 +189,7 @@ if TYPE_CHECKING:
...
@@ -189,6 +189,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSE_SILU_AND_MUL
:
bool
=
False
VLLM_USE_FUSE_SILU_AND_MUL
:
bool
=
False
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
bool
=
False
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
bool
=
False
VLLM_USE_TOPK_RENORM
:
bool
=
False
VLLM_USE_TOPK_RENORM
:
bool
=
False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_USE_TOPK_RENORM"
,
"True"
).
lower
()
in
(
os
.
environ
.
get
(
"VLLM_USE_TOPK_RENORM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# If set, vLLM uses the fused rmsnorm + contiguous + rope (for dpsk-v3) + concat_and_cache_mla kernel
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
,
'False'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/model_loader/utils.py
View file @
64e307c7
...
@@ -255,6 +255,8 @@ def get_model_architecture(
...
@@ -255,6 +255,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else
:
else
:
if
not
envs
.
is_set
(
"VLLM_USE_PD_SPLIT"
):
if
not
envs
.
is_set
(
"VLLM_USE_PD_SPLIT"
):
os
.
environ
[
'VLLM_USE_PD_SPLIT'
]
=
'1'
os
.
environ
[
'VLLM_USE_PD_SPLIT'
]
=
'1'
...
@@ -267,8 +269,8 @@ def get_model_architecture(
...
@@ -267,8 +269,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_LIGHTOP_MOE_SUM'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP_MOE_SUM'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_FUSE_SILU_AND_MUL"
):
if
not
envs
.
is_set
(
"VLLM_USE_FUSE_SILU_AND_MUL"
):
os
.
environ
[
'VLLM_USE_FUSE_SILU_AND_MUL'
]
=
'1'
os
.
environ
[
'VLLM_USE_FUSE_SILU_AND_MUL'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_OPT_RESHAPE_AND_CACHE"
):
#
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
#
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
...
@@ -286,6 +288,8 @@ def get_model_architecture(
...
@@ -286,6 +288,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
else
:
else
:
if
not
envs
.
is_set
(
"VLLM_USE_PD_SPLIT"
):
if
not
envs
.
is_set
(
"VLLM_USE_PD_SPLIT"
):
os
.
environ
[
'VLLM_USE_PD_SPLIT'
]
=
'1'
os
.
environ
[
'VLLM_USE_PD_SPLIT'
]
=
'1'
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
64e307c7
...
@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
,
_yarn_find_correction_range
,
_yarn_linear_ramp_mask
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -64,7 +64,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
...
@@ -64,7 +64,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix
)
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -607,6 +608,52 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -607,6 +608,52 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
self
.
max_position_embeddings
=
rope_scaling
[
"original_max_position_embeddings"
]
self
.
base
=
rope_theta
self
.
rotary_dim
=
qk_rope_head_dim
self
.
scaling_factor
=
scaling_factor
self
.
mscale
=
mscale
self
.
extrapolation_factor
=
1
self
.
beta_fast
=
32
self
.
beta_slow
=
1
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
"cuda"
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Return inverse rotary frequencies blended between two regimes.

    Two candidate inverse-frequency vectors are built from ``self.base``
    and ``self.rotary_dim``: an "extrapolation" one (``1 / freq``) and an
    "interpolation" one (``1 / (scaling_factor * freq)``). They are mixed
    element-wise with a mask derived from ``_yarn_linear_ramp_mask`` over
    the range returned by ``_yarn_find_correction_range`` and scaled by
    ``self.extrapolation_factor``.

    Args:
        scaling_factor: Divisor applied to the interpolation frequencies.

    Returns:
        The blended inverse-frequency tensor.
    """
    rot_dim = self.rotary_dim
    exponents = torch.arange(
        0, rot_dim, 2, dtype=torch.float, device="cuda") / rot_dim
    base_freqs = self.base**exponents

    extrapolated = 1.0 / base_freqs
    interpolated = 1.0 / (scaling_factor * base_freqs)

    low, high = _yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        rot_dim,
        self.base,
        self.max_position_embeddings,
    )
    # Get n-d rotational scaling corrected for extrapolation
    ramp = _yarn_linear_ramp_mask(low, high, rot_dim // 2, dtype=torch.float)
    blend = (1 - ramp) * self.extrapolation_factor

    # blend == 1 selects the extrapolated value, blend == 0 the interpolated.
    return interpolated * (1 - blend) + extrapolated * blend
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the rotary cos/sin lookup table.

    Positions ``0 .. max_position_embeddings * scaling_factor`` are
    multiplied (outer product) with the inverse frequencies from
    ``_compute_inv_freq``. The cosine and sine of the resulting angles
    are each scaled by ``self.mscale``.

    Returns:
        A float32 tensor with the scaled cosines followed by the scaled
        sines along the last dimension.
    """
    num_positions = self.max_position_embeddings * self.scaling_factor
    inverse_freqs = self._compute_inv_freq(self.scaling_factor)
    position_ids = torch.arange(
        num_positions, device="cuda", dtype=torch.float32)

    # angles[i, j] = position_ids[i] * inverse_freqs[j]
    angles = torch.einsum("i,j -> ij", position_ids, inverse_freqs)

    scaled_halves = [
        trig * self.mscale for trig in (angles.cos(), angles.sin())
    ]
    return torch.cat(scaled_halves, dim=-1)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -697,24 +744,47 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -697,24 +744,47 @@ class DeepseekV2MLAAttention(nn.Module):
q
=
self
.
q_proj
(
hidden_states
)[
0
]
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_LIGHTOP
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
kv_c_normed
=
self
.
kv_a_layernorm
.
forward_cuda_opt
(
kv_c
)
if
envs
.
VLLM_USE_LIGHTOP
:
else
:
kv_c_normed
=
self
.
kv_a_layernorm
.
forward_cuda_opt
(
kv_c
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
else
:
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
attn_out
=
self
.
mla_attn
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
self
.
num_local_heads
*
self
.
v_head_dim
))
else
:
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
weight
=
torch
.
ones
(
kv_c
.
shape
[
-
1
],
dtype
=
q
.
dtype
,
device
=
kv_c
.
device
)
weight
=
nn
.
Parameter
(
weight
)
if
self
.
cos_sin_cache
.
device
!=
positions
.
device
:
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
# Keep the cos/sin table in the same dtype as the query. The guard must
# compare dtypes: the previous check compared ``.device`` with ``q.dtype``,
# which is never equal, so the cast and buffer re-assignment ran on every
# forward pass.
if self.cos_sin_cache.dtype != q.dtype:
    self.cos_sin_cache = self.cos_sin_cache.to(q.dtype)
kv_c_normed
=
torch
.
empty
(
kv_c
.
shape
,
dtype
=
kv_c
.
dtype
,
device
=
kv_c
.
device
)
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
kv_c
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
),
q_ori
=
q
,
key_normed
=
kv_c_normed
,
positions
=
positions
,
weight
=
weight
.
data
,
cos_sin_cache
=
self
.
cos_sin_cache
)
return
self
.
o_proj
(
attn_out
)[
0
]
return
self
.
o_proj
(
attn_out
)[
0
]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
64e307c7
...
@@ -1095,6 +1095,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1095,6 +1095,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata
:
M
,
attn_metadata
:
M
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
...
@@ -1129,22 +1134,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1129,22 +1134,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_q
=
q
[:
num_decode_tokens
]
decode_q
=
q
[:
num_decode_tokens
]
prefill_q
=
q
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
decode_q
=
q
[:
num_decode_tokens
]
prefill_q
=
q
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
else
:
q_ori
=
q_ori
[:
num_actual_toks
,
...]
decode_q
=
q_ori
[:
num_decode_tokens
]
prefill_q
=
q_ori
[
num_decode_tokens
:]
# write the latent and rope to kv cache
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
k_c_normed
,
ops
.
concat_and_cache_mla
(
k_pe
.
squeeze
(
1
),
k_c_normed
,
kv_cache
,
k_pe
.
squeeze
(
1
),
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
attn_metadata
.
slot_mapping
.
flatten
(),
scale
=
layer
.
_k_scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
scale
=
layer
.
_k_scale
,
)
else
:
from
lightop
import
fused_rms_norm_rope_contiguous
if
self
.
kv_cache_dtype
==
"auto"
:
if
q
.
dtype
==
torch
.
float16
:
kv_cache_dtype_str
=
"fp16"
elif
q
.
dtype
==
torch
.
bfloat16
:
kv_cache_dtype_str
=
"bf16"
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
fused_rms_norm_rope_contiguous
(
positions
,
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
,
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
if
has_prefill
:
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
prefill_k_c_normed
=
key_normed
[
num_decode_tokens
:]
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment