Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
153002ad
Commit
153002ad
authored
Apr 23, 2026
by
wanghl6
Browse files
[Perf]融合算子优化
parent
aef3c487
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
127 additions
and
49 deletions
+127
-49
vllm/envs.py
vllm/envs.py
+9
-1
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+44
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+55
-40
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+19
-8
No files found.
vllm/envs.py
View file @
153002ad
...
...
@@ -325,7 +325,8 @@ if TYPE_CHECKING:
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_LIGHTLY_CP_THRESHOULD
:
int
=
2048
USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE
:
bool
=
False
USE_LIGHTOP_FUSE_LN_ROPE_QUANT
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -2012,6 +2013,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_LIGHTLY_CP_THRESHOULD"
,
"2048"
)),
"USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"USE_LIGHTOP_FUSE_LN_ROPE_QUANT"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_FUSE_LN_ROPE_QUANT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
153002ad
...
...
@@ -1067,7 +1067,51 @@ def per_token_group_quant_fp8(
return
x_q
,
x_s
def _lightop_fuse_norm_rope_quant_fp8_impl(
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    head_dim: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    is_rmsnorm: bool,
    weight_k: torch.Tensor | None,
    bias_k: torch.Tensor | None,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run lightop's ``fuse_norm_rope_quant_fp8`` kernel on the given inputs.

    Every argument is forwarded positionally, unchanged and in declaration
    order, to ``lightop.op.fuse_norm_rope_quant_fp8``; this function adds no
    logic of its own.

    Args:
        positions: Forwarded as-is to the kernel.
        q: Forwarded as-is to the kernel.
        k: Forwarded as-is to the kernel.
        head_dim: Forwarded as-is to the kernel.
        cos_sin_cache: Forwarded as-is to the kernel.
        is_neox: Forwarded as-is to the kernel.
        is_rmsnorm: Forwarded as-is to the kernel.
        weight_k: Forwarded as-is to the kernel; may be ``None``.
        bias_k: Forwarded as-is to the kernel; may be ``None``.
        eps: Forwarded as-is to the kernel.

    Returns:
        Whatever the lightop kernel returns, annotated as a 3-tuple of
        tensors.
        NOTE(review): presumably ``(k_out, q_fp8, q_scale)``, matching the
        order produced by the fake implementation -- confirm against lightop.
    """
    # Imported at call time rather than at module import time, so merely
    # importing this module does not require lightop to be installed.
    from lightop import op as lightop_ops

    fused_kernel = lightop_ops.fuse_norm_rope_quant_fp8
    return fused_kernel(
        positions,
        q,
        k,
        head_dim,
        cos_sin_cache,
        is_neox,
        is_rmsnorm,
        weight_k,
        bias_k,
        eps,
    )
def _lightop_fuse_norm_rope_quant_fp8_fake(
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    head_dim: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    is_rmsnorm: bool,
    weight_k: torch.Tensor | None,
    bias_k: torch.Tensor | None,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate uninitialised outputs with the layout of the fused op's results.

    No normalisation, rotary embedding or quantisation is computed here; only
    ``q`` and ``k`` are inspected (for shape, dtype and device). All other
    parameters exist solely to mirror the real implementation's signature.

    Returns:
        A ``(k_out, q_fp8_out, q_scale_out)`` tuple where:
          * ``k_out`` has the same shape, dtype and device as ``k``;
          * ``q_fp8_out`` has the same shape and device as ``q`` but an FP8
            dtype;
          * ``q_scale_out`` is float32 on ``q``'s device with shape
            ``(q.shape[0], q.shape[1], 1)``.
    """
    k_placeholder = torch.empty_like(k)

    # Use the platform's FP8 dtype when it exposes one, otherwise fall back
    # to e4m3fn.
    if hasattr(current_platform, "fp8_dtype"):
        quant_dtype = current_platform.fp8_dtype()
    else:
        quant_dtype = torch.float8_e4m3fn
    q_placeholder = torch.empty_like(q, dtype=quant_dtype)

    # One scale slot per (dim-0, dim-1) entry of q.
    # NOTE(review): assumes q is at least 2-D here -- confirm against caller.
    scale_shape = (q.shape[0], q.shape[1], 1)
    scale_placeholder = torch.empty(
        scale_shape, dtype=torch.float32, device=q.device
    )

    return k_placeholder, q_placeholder, scale_placeholder
# Register the lightop fused kernel wrapper as a custom op named
# "lightop_fuse_norm_rope_quant_fp8". The real implementation calls into
# lightop; the fake implementation only allocates correctly shaped outputs.
# No argument is declared as mutated in place (empty ``mutates_args``).
direct_register_custom_op(
    op_name="lightop_fuse_norm_rope_quant_fp8",
    op_func=_lightop_fuse_norm_rope_quant_fp8_impl,
    fake_impl=_lightop_fuse_norm_rope_quant_fp8_fake,
    mutates_args=[],
)
def
per_token_group_quant_fp8_packed_for_deepgemm
(
x
:
torch
.
Tensor
,
group_size
:
int
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
153002ad
...
...
@@ -860,53 +860,68 @@ class Indexer(nn.Module):
)
->
torch
.
Tensor
:
q
,
_
=
self
.
wq_b
(
qr
)
q
=
q
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_pe
,
q_nope
=
torch
.
split
(
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
if
envs
.
USE_FUSED_RMS_QUANT
and
self
.
wk
.
weight
.
dtype
==
torch
.
int8
and
iqis
is
not
None
:
k
,
_
=
self
.
wk
(
hidden_states
,
iqis
=
iqis
)
else
:
k
,
_
=
self
.
wk
(
hidden_states
)
k
=
self
.
k_norm
(
k
)
k_pe
,
k_nope
=
torch
.
split
(
k
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
q_pe
,
k_pe
=
rotary_emb
(
positions
,
q_pe
,
k_pe
.
unsqueeze
(
1
))
# Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
# so we need to reshape back to token-flattened shapes
q_pe
=
q_pe
.
reshape
(
-
1
,
self
.
n_head
,
self
.
rope_dim
)
k_pe
=
k_pe
.
reshape
(
-
1
,
1
,
self
.
rope_dim
)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
if
envs
.
USE_LIGHTOP_FUSE_LN_ROPE_QUANT
:
is_rmsnorm
=
not
hasattr
(
self
.
k_norm
,
'bias'
)
or
self
.
k_norm
.
bias
is
None
weight_k
=
getattr
(
self
.
k_norm
,
'weight'
,
None
)
bias_k
=
getattr
(
self
.
k_norm
,
'bias'
,
None
)
eps
=
getattr
(
self
.
k_norm
,
'eps'
,
1e-5
)
cos_sin_cache
=
getattr
(
rotary_emb
,
'cos_sin_cache'
,
None
)
is_neox
=
getattr
(
rotary_emb
,
'is_neox'
,
True
)
k
,
q_fp8
,
q_scale
=
torch
.
ops
.
vllm
.
lightop_fuse_norm_rope_quant_fp8
(
positions
,
q
,
self
.
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
self
.
scale_fmt
is
not
None
,
k
,
self
.
head_dim
,
cos_sin_cache
,
is_neox
,
is_rmsnorm
,
weight_k
,
bias_k
,
eps
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
if
current_platform
.
is_rocm
()
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
!=
"gfx938"
:
q_fp8
=
q
q_scale
=
None
else
:
q_fp8
=
q
q_pe
,
q_nope
=
torch
.
split
(
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
k
=
self
.
k_norm
(
k
)
k_pe
,
k_nope
=
torch
.
split
(
k
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
q_pe
,
k_pe
=
rotary_emb
(
positions
,
q_pe
,
k_pe
.
unsqueeze
(
1
))
# Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
# so we need to reshape back to token-flattened shapes
q_pe
=
q_pe
.
reshape
(
-
1
,
self
.
n_head
,
self
.
rope_dim
)
k_pe
=
k_pe
.
reshape
(
-
1
,
1
,
self
.
rope_dim
)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
q
,
self
.
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
self
.
scale_fmt
is
not
None
,
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
else
:
q_fp8
=
q
q_scale
=
None
if
envs
.
USE_FUSED_RMS_QUANT
and
self
.
weights_proj
.
weight
.
dtype
==
torch
.
int8
and
iqis
is
not
None
:
weights
,
_
=
self
.
weights_proj
(
hidden_states
,
iqis
=
iqis
)
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
153002ad
...
...
@@ -868,14 +868,25 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
assert
fp8_metadata
.
prefill
is
not
None
for
chunk
in
fp8_metadata
.
prefill
.
chunks
:
chunk_workspace
=
self
.
prefill_bf16_workspace
[:
chunk
.
chunk_tot_seqlen
]
ops
.
cp_gather_and_upconvert_fp8_kv_cache
(
kv_c_and_k_pe_cache
,
chunk_workspace
,
chunk
.
block_table
,
chunk
.
seq_lens
,
chunk
.
workspace_starts
,
len
(
chunk
.
block_table
),
)
if
not
envs
.
USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE
:
ops
.
cp_gather_and_upconvert_fp8_kv_cache
(
kv_c_and_k_pe_cache
,
chunk_workspace
,
chunk
.
block_table
,
chunk
.
seq_lens
,
chunk
.
workspace_starts
,
len
(
chunk
.
block_table
),
)
else
:
from
lightop
import
op
op
.
cp_gather_and_upconvert_fp8_kv_cache
(
kv_c_and_k_pe_cache
,
chunk_workspace
,
chunk
.
block_table
,
chunk
.
seq_lens
,
chunk
.
workspace_starts
,
len
(
chunk
.
block_table
),
)
chunk_q
=
q
[
chunk
.
tokens_slice
]
chunk_topk_indices_workspace
=
topk_indices
[
chunk
.
tokens_slice
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment