Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82cd3c88
Commit
82cd3c88
authored
Dec 24, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
# Conflicts: # vllm/envs.py
parents
35e43dfb
7d5faa43
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
410 additions
and
308 deletions
+410
-308
vllm/attention/layer.py
vllm/attention/layer.py
+13
-24
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+64
-26
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+123
-5
vllm/envs.py
vllm/envs.py
+6
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+0
-4
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+30
-6
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+173
-240
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
No files found.
vllm/attention/layer.py
View file @
82cd3c88
...
...
@@ -553,31 +553,20 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
else
:
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
weight
:
Optional
[
torch
.
Tensor
]
=
None
,
cos_sin_cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
return
direct_register_custom_op
(
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
82cd3c88
...
...
@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_local_rank
=
get_world_group
().
local_rank
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_rank
=
get_dp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_rank
=
get_pp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_rank
=
get_tp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_size
=
get_dp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_size
=
get_pp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
local_rank
=
self
.
_local_rank
,
config
=
self
.
config
,
hostname
=
""
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
multiple_machines
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines
==
1
:
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
...
...
@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
...
...
@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
multiple_machines
):
if
(
self
.
multiple_machines
_p
and
self
.
multiple_machines_d
):
ip_second
=
self
.
get_ip_value
(
ip
)
if
(
self
.
pp_size
==
1
):
if
self
.
_rank
<
8
:
...
...
@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
else
:
print
(
"Error: only suppprt pp1 pp2 !!!!!!"
)
else
:
if
(
self
.
pp_size
==
1
):
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
if
(
self
.
pp_size
==
2
):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
))
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
))
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
i
))
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
self
.
p2p_nccl_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
is_mla
)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
logger
.
error
(
"Error: not support!!!!!!"
)
def
wait_for_save
(
self
):
pass
# if self.is_producer:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
82cd3c88
...
...
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
msgpack
import
torch
import
zmq
import
regex
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
...
...
@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
dataclasses
import
dataclass
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.config
import
ModelConfig
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
...
...
@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB
=
32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@
contextmanager
def
set_p2p_nccl_context
(
num_channels
:
str
):
...
...
@@ -65,22 +76,39 @@ class P2pNcclEngine:
def
__init__
(
self
,
local_rank
:
int
,
port_offset
:
int
,
config
:
KVTransferConfig
,
hostname
:
str
=
""
,
port_offset
:
int
=
0
,
model_config
:
ModelConfig
,
library_path
:
Optional
[
str
]
=
None
)
->
None
:
self
.
config
=
config
self
.
model_config
=
model_config
self
.
rank
=
port_offset
self
.
local_rank
=
local_rank
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
if
not
hostname
:
hostname
=
get_ip
()
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
"num_hidden_layers"
,
0
)
self
.
pp_rank
=
get_pp_group
().
rank_in_group
self
.
tp_rank
=
get_tp_group
().
rank_in_group
self
.
pp_size
=
get_pp_group
().
world_size
self
.
tp_size
=
get_tp_group
().
world_size
if
config
.
is_kv_producer
:
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
1
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
1
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
if
self
.
remote_tp_size
%
self
.
tp_size
!=
0
:
logger
.
error
(
" the Prefill TP size must be less than or equal to the Decode TP size!!!!"
)
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
self
.
_hostname
=
hostname
self
.
_hostname
=
get_ip
()
self
.
_port
=
port
# Each card corresponds to a ZMQ address.
...
...
@@ -195,6 +223,61 @@ class P2pNcclEngine:
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
def
get_send_queue_items
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
)
->
list
[
any
]:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
remote_ip
,
remote_port
=
self
.
parse_request_id
(
request_id
,
True
)
if
not
self
.
enable_asymmetric_p2p
:
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
self
.
rank
)
return
[(
tensor_id
,
remote_address
,
tensor
)]
if
not
is_mla
:
logger
.
error
(
" P2PNCCL only support mla model symmetric PP/TP!!!!"
)
remote_pp_rank
=
self
.
compute_remote_pp_rank
(
layer_name
)
items
:
list
[
Any
]
=
[]
up_down
=
1
# remote_tp_rank = self.tp_rank * self.multp
for
d_tp_rank
in
range
(
self
.
remote_tp_size
):
for
mul_tp
in
range
(
self
.
multp
):
if
self
.
tp_rank
+
mul_tp
*
self
.
tp_size
==
d_tp_rank
:
remote_port_offset
=
remote_pp_rank
*
self
.
remote_tp_size
+
d_tp_rank
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
remote_port_offset
)
logger
.
debug
(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)"
,
tensor_id
,
tensor
.
shape
,
self
.
pp_rank
,
self
.
tp_rank
,
remote_address
,
remote_pp_rank
,
self
.
rank
*
mul_tp
+
self
.
rank
)
items
.
append
([
tensor_id
,
remote_address
,
tensor
])
return
items
def
send_tensor_new
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
=
False
,
)
->
bool
:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
if
self
.
send_type
==
"PUT"
:
return
all
(
self
.
send_sync
(
item
)
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
))
if
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
):
self
.
send_queue
.
append
(
item
)
self
.
send_queue_cv
.
notify
()
return
True
if
self
.
send_type
==
"GET"
:
logger
.
error
(
" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!"
)
def
send_tensor
(
self
,
tensor_id
:
str
,
...
...
@@ -659,3 +742,38 @@ class P2pNcclEngine:
self
.
_send_thread
.
join
()
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
def
compute_remote_pp_rank
(
self
,
layer_name
:
str
)
->
int
:
current_layer_idx
=
extract_layer_index
(
layer_name
)
for
d_pp_rank
in
range
(
self
.
remote_pp_size
):
start
,
end
=
get_pp_indices
(
self
.
total_num_hidden_layers
,
d_pp_rank
,
self
.
remote_pp_size
)
logger
.
info
(
f
"""compute_remote_pp_rank : current_layer_idx:
{
current_layer_idx
}
start:
{
start
}
end:
{
end
}
"""
)
if
(
current_layer_idx
==
self
.
total_num_hidden_layers
):
return
self
.
remote_pp_size
-
1
if
start
<=
current_layer_idx
<
end
:
return
d_pp_rank
return
-
1
@
staticmethod
def
get_tensor_id
(
request_id
:
str
,
layer_name
:
str
)
->
str
:
return
request_id
+
"#"
+
layer_name
@
staticmethod
def
parse_request_id
(
request_id
:
str
,
is_prefill
=
True
)
->
tuple
[
str
,
int
]:
# Regular expression to match the string hostname and integer port
if
is_prefill
:
pattern
=
r
"___decode_addr_(.*):(\d+)"
else
:
pattern
=
r
"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match
=
regex
.
search
(
pattern
,
request_id
)
if
match
:
# Extract the ranks
ip
=
match
.
group
(
1
)
port
=
int
(
match
.
group
(
2
))
return
ip
,
port
raise
ValueError
(
f
"Request id
{
request_id
}
does not contain hostname and port"
)
\ No newline at end of file
vllm/envs.py
View file @
82cd3c88
...
...
@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
def
get_default_cache_root
():
...
...
@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE"
:
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
0
"
))),
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
1
"
))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX"
:
...
...
@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MARLIN_W16A16_MOE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/model_loader/utils.py
View file @
82cd3c88
...
...
@@ -253,8 +253,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
...
@@ -298,8 +296,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
82cd3c88
...
...
@@ -28,6 +28,8 @@ from .interfaces import SupportsPP
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
class
SharedHead
(
nn
.
Module
):
...
...
@@ -71,6 +73,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self
.
shared_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mtp_block
=
DeepseekV2DecoderLayer
(
config
,
prefix
,
model_config
,
cache_config
,
quant_config
)
def
fuse_fill_rms_x2_concat
(
hidden_states_fuse
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
weight_inputs_embeds
:
torch
.
Tensor
,
weight_previous_hidden_states
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
from
lightop
import
fuse_fill_rms_x2_concat
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
weight_inputs_embeds
,
weight_previous_hidden_states
,
epsilon
)
def
fuse_fill_rms_x2_concat_fake
(
hidden_states_fuse
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
weight_inputs_embeds
:
torch
.
Tensor
,
weight_previous_hidden_states
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
pass
direct_register_custom_op
(
op_name
=
"fuse_fill_rms_x2_concat"
,
op_func
=
fuse_fill_rms_x2_concat
,
mutates_args
=
[
"hidden_states_fuse"
,
"inputs_embeds"
],
fake_impl
=
fuse_fill_rms_x2_concat_fake
,
)
def
forward
(
self
,
...
...
@@ -84,12 +104,16 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
if
envs
.
VLLM_USE_FUSED_FILL_RMS_CAT
:
hidden_states_fuse
=
torch
.
empty
(
inputs_embeds
.
shape
[
0
],
inputs_embeds
.
shape
[
1
]
*
2
,
device
=
inputs_embeds
.
device
,
dtype
=
inputs_embeds
.
dtype
)
torch
.
ops
.
vllm
.
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
self
.
enorm
.
weight
,
self
.
hnorm
.
weight
,
self
.
enorm
.
variance_epsilon
)
hidden_states
=
self
.
eh_proj
(
hidden_states_fuse
)
else
:
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
hidden_states
,
residual
=
self
.
mtp_block
(
positions
=
positions
,
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
82cd3c88
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/mla/common.py
View file @
82cd3c88
...
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
from
lightop
import
fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment