Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82cd3c88
Commit
82cd3c88
authored
Dec 24, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
# Conflicts: # vllm/envs.py
parents
35e43dfb
7d5faa43
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
410 additions
and
308 deletions
+410
-308
vllm/attention/layer.py
vllm/attention/layer.py
+13
-24
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+64
-26
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+123
-5
vllm/envs.py
vllm/envs.py
+6
-2
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+0
-4
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+30
-6
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+173
-240
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
No files found.
vllm/attention/layer.py
View file @
82cd3c88
...
@@ -553,18 +553,7 @@ def unified_attention_with_output(
...
@@ -553,18 +553,7 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
def
unified_attention_with_output_fake
(
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
else
:
def
unified_attention_with_output_fake
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
82cd3c88
...
@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
...
@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
,
get_dp_group
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
role
==
KVConnectorRole
.
WORKER
else
0
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_local_rank
=
get_world_group
().
local_rank
\
self
.
_local_rank
=
get_world_group
().
local_rank
\
if
role
==
KVConnectorRole
.
WORKER
else
0
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_rank
=
get_dp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_rank
=
get_pp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_rank
=
get_tp_group
().
rank_in_group
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_dp_size
=
get_dp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_pp_size
=
get_pp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
_tp_size
=
get_tp_group
().
world_size
\
if
role
==
KVConnectorRole
.
WORKER
else
0
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
self
.
p2p_nccl_engine
=
P2pNcclEngine
(
local_rank
=
self
.
_local_rank
,
local_rank
=
self
.
_local_rank
,
config
=
self
.
config
,
hostname
=
""
,
port_offset
=
self
.
_rank
,
port_offset
=
self
.
_rank
,
config
=
self
.
config
,
model_config
=
vllm_config
.
model_config
,
)
if
role
==
KVConnectorRole
.
WORKER
else
None
)
if
role
==
KVConnectorRole
.
WORKER
else
None
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
num_card
=
self
.
pp_size
*
self
.
tp_size
self
.
multiple_machines
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines
==
1
:
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
self
.
tp_size
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
self
.
pp_size
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
self
.
remote_num_card
=
self
.
remote_tp_size
*
self
.
remote_pp_size
self
.
multiple_machines_d
=
1
if
self
.
remote_num_card
>
8
else
0
self
.
multiple_machines_p
=
1
if
self
.
num_card
>
8
else
0
if
self
.
is_producer
and
self
.
multiple_machines_p
==
1
:
self
.
ip_map
=
{}
self
.
ip_map
=
{}
self
.
duplicate_keys
=
[]
self
.
duplicate_keys
=
[]
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
config_file
=
os
.
getenv
(
'IP_CONFIG_FILE'
)
...
@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
assert
self
.
p2p_nccl_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
def
extract_kv_from_layer
(
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
...
@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
multiple_machines
):
if
(
self
.
multiple_machines
_p
and
self
.
multiple_machines_d
):
ip_second
=
self
.
get_ip_value
(
ip
)
ip_second
=
self
.
get_ip_value
(
ip
)
if
(
self
.
pp_size
==
1
):
if
(
self
.
pp_size
==
1
):
if
self
.
_rank
<
8
:
if
self
.
_rank
<
8
:
...
@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
kv_cache
,
str
(
ip_second
)
+
":"
+
str
(
port
+
self
.
_rank
))
else
:
else
:
print
(
"Error: only suppprt pp1 pp2 !!!!!!"
)
logger
.
error
(
"Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!"
)
else
:
elif
(
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
if
(
self
.
pp_size
==
1
):
if
(
self
.
pp_size
==
2
):
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_tp_rank
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_address
)
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
))
else
:
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
logger
.
error
(
"Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!"
)
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
elif
(
not
self
.
multiple_machines_p
and
not
self
.
multiple_machines_d
):
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
))
self
.
p2p_nccl_engine
.
send_tensor_new
(
request_id
,
layer_name
,
kv_cache
,
elif
(
self
.
pp_size
==
8
):
is_mla
)
for
i
in
range
(
8
):
# if (self.pp_size == 1):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache
,
ip
+
":"
+
str
(
port
+
i
))
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else
:
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
logger
.
error
(
"Error: not support!!!!!!"
)
def
wait_for_save
(
self
):
def
wait_for_save
(
self
):
pass
pass
# if self.is_producer:
# if self.is_producer:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
82cd3c88
...
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
...
@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
msgpack
import
msgpack
import
torch
import
torch
import
zmq
import
zmq
import
regex
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
...
@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
...
@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
from
vllm
import
envs
from
vllm
import
envs
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
dataclasses
import
dataclass
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.config
import
ModelConfig
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
...
@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB
=
32
DEFAULT_MEM_POOL_SIZE_GB
=
32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@
contextmanager
@
contextmanager
def
set_p2p_nccl_context
(
num_channels
:
str
):
def
set_p2p_nccl_context
(
num_channels
:
str
):
...
@@ -65,22 +76,39 @@ class P2pNcclEngine:
...
@@ -65,22 +76,39 @@ class P2pNcclEngine:
def
__init__
(
self
,
def
__init__
(
self
,
local_rank
:
int
,
local_rank
:
int
,
port_offset
:
int
,
config
:
KVTransferConfig
,
config
:
KVTransferConfig
,
hostname
:
str
=
""
,
model_config
:
ModelConfig
,
port_offset
:
int
=
0
,
library_path
:
Optional
[
str
]
=
None
)
->
None
:
library_path
:
Optional
[
str
]
=
None
)
->
None
:
self
.
config
=
config
self
.
config
=
config
self
.
model_config
=
model_config
self
.
rank
=
port_offset
self
.
rank
=
port_offset
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
self
.
nccl
=
NCCLLibrary
(
library_path
)
if
not
hostname
:
self
.
total_num_hidden_layers
=
getattr
(
self
.
model_config
.
hf_text_config
,
hostname
=
get_ip
()
"num_hidden_layers"
,
0
)
self
.
pp_rank
=
get_pp_group
().
rank_in_group
self
.
tp_rank
=
get_tp_group
().
rank_in_group
self
.
pp_size
=
get_pp_group
().
world_size
self
.
tp_size
=
get_tp_group
().
world_size
if
config
.
is_kv_producer
:
self
.
remote_tp_size
=
self
.
config
.
get_from_extra_config
(
"remote_tp_size"
,
1
)
self
.
remote_pp_size
=
self
.
config
.
get_from_extra_config
(
"remote_pp_size"
,
1
)
self
.
enable_asymmetric_p2p
=
self
.
config
.
get_from_extra_config
(
"enable_asymmetric_p2p"
,
False
)
if
self
.
remote_tp_size
%
self
.
tp_size
!=
0
:
logger
.
error
(
" the Prefill TP size must be less than or equal to the Decode TP size!!!!"
)
self
.
multp
=
int
(
self
.
remote_tp_size
/
self
.
tp_size
)
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
port
=
int
(
self
.
config
.
kv_port
)
+
port_offset
if
port
==
0
:
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
raise
ValueError
(
"Port cannot be 0"
)
self
.
_hostname
=
hostname
self
.
_hostname
=
get_ip
()
self
.
_port
=
port
self
.
_port
=
port
# Each card corresponds to a ZMQ address.
# Each card corresponds to a ZMQ address.
...
@@ -195,6 +223,61 @@ class P2pNcclEngine:
...
@@ -195,6 +223,61 @@ class P2pNcclEngine:
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
def
get_send_queue_items
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
)
->
list
[
any
]:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
remote_ip
,
remote_port
=
self
.
parse_request_id
(
request_id
,
True
)
if
not
self
.
enable_asymmetric_p2p
:
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
self
.
rank
)
return
[(
tensor_id
,
remote_address
,
tensor
)]
if
not
is_mla
:
logger
.
error
(
" P2PNCCL only support mla model symmetric PP/TP!!!!"
)
remote_pp_rank
=
self
.
compute_remote_pp_rank
(
layer_name
)
items
:
list
[
Any
]
=
[]
up_down
=
1
# remote_tp_rank = self.tp_rank * self.multp
for
d_tp_rank
in
range
(
self
.
remote_tp_size
):
for
mul_tp
in
range
(
self
.
multp
):
if
self
.
tp_rank
+
mul_tp
*
self
.
tp_size
==
d_tp_rank
:
remote_port_offset
=
remote_pp_rank
*
self
.
remote_tp_size
+
d_tp_rank
remote_address
=
remote_ip
+
":"
+
str
(
remote_port
+
remote_port_offset
)
logger
.
debug
(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)"
,
tensor_id
,
tensor
.
shape
,
self
.
pp_rank
,
self
.
tp_rank
,
remote_address
,
remote_pp_rank
,
self
.
rank
*
mul_tp
+
self
.
rank
)
items
.
append
([
tensor_id
,
remote_address
,
tensor
])
return
items
def
send_tensor_new
(
self
,
request_id
:
str
,
layer_name
:
str
,
tensor
:
torch
.
Tensor
,
is_mla
:
bool
=
False
,
)
->
bool
:
tensor_id
=
self
.
get_tensor_id
(
request_id
,
layer_name
)
if
self
.
send_type
==
"PUT"
:
return
all
(
self
.
send_sync
(
item
)
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
))
if
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
for
item
in
self
.
get_send_queue_items
(
request_id
,
layer_name
,
tensor
,
is_mla
):
self
.
send_queue
.
append
(
item
)
self
.
send_queue_cv
.
notify
()
return
True
if
self
.
send_type
==
"GET"
:
logger
.
error
(
" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!"
)
def
send_tensor
(
def
send_tensor
(
self
,
self
,
tensor_id
:
str
,
tensor_id
:
str
,
...
@@ -659,3 +742,38 @@ class P2pNcclEngine:
...
@@ -659,3 +742,38 @@ class P2pNcclEngine:
self
.
_send_thread
.
join
()
self
.
_send_thread
.
join
()
if
self
.
_ping_thread
is
not
None
:
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
self
.
_ping_thread
.
join
()
def
compute_remote_pp_rank
(
self
,
layer_name
:
str
)
->
int
:
current_layer_idx
=
extract_layer_index
(
layer_name
)
for
d_pp_rank
in
range
(
self
.
remote_pp_size
):
start
,
end
=
get_pp_indices
(
self
.
total_num_hidden_layers
,
d_pp_rank
,
self
.
remote_pp_size
)
logger
.
info
(
f
"""compute_remote_pp_rank : current_layer_idx:
{
current_layer_idx
}
start:
{
start
}
end:
{
end
}
"""
)
if
(
current_layer_idx
==
self
.
total_num_hidden_layers
):
return
self
.
remote_pp_size
-
1
if
start
<=
current_layer_idx
<
end
:
return
d_pp_rank
return
-
1
@
staticmethod
def
get_tensor_id
(
request_id
:
str
,
layer_name
:
str
)
->
str
:
return
request_id
+
"#"
+
layer_name
@
staticmethod
def
parse_request_id
(
request_id
:
str
,
is_prefill
=
True
)
->
tuple
[
str
,
int
]:
# Regular expression to match the string hostname and integer port
if
is_prefill
:
pattern
=
r
"___decode_addr_(.*):(\d+)"
else
:
pattern
=
r
"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match
=
regex
.
search
(
pattern
,
request_id
)
if
match
:
# Extract the ranks
ip
=
match
.
group
(
1
)
port
=
int
(
match
.
group
(
2
))
return
ip
,
port
raise
ValueError
(
f
"Request id
{
request_id
}
does not contain hostname and port"
)
\ No newline at end of file
vllm/envs.py
View file @
82cd3c88
...
@@ -196,6 +196,7 @@ if TYPE_CHECKING:
...
@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE"
:
"VLLM_CUSTOM_CACHE"
:
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
0
"
))),
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_CUSTOM_CACHE"
,
"
1
"
))),
# flag to control vllm to use optimized kernels
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX"
:
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX"
:
...
@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE"
:
"VLLM_USE_MARLIN_W16A16_MOE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MARLIN_W16A16_MOE"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MARLIN_W16A16_MOE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_FILL_RMS_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use deepgemm kernel for deepep ht mode
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/model_loader/utils.py
View file @
82cd3c88
...
@@ -253,8 +253,6 @@ def get_model_architecture(
...
@@ -253,8 +253,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
@@ -298,8 +296,6 @@ def get_model_architecture(
...
@@ -298,8 +296,6 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_CAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
if
not
envs
.
is_set
(
"VLLM_USE_CAT_MLA"
):
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
82cd3c88
...
@@ -28,6 +28,8 @@ from .interfaces import SupportsPP
...
@@ -28,6 +28,8 @@ from .interfaces import SupportsPP
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
from
vllm.model_executor.layers.quantization.blockwise_int8
import
BlockInt8Config
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
class
SharedHead
(
nn
.
Module
):
class
SharedHead
(
nn
.
Module
):
...
@@ -72,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
...
@@ -72,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self
.
mtp_block
=
DeepseekV2DecoderLayer
(
config
,
prefix
,
model_config
,
self
.
mtp_block
=
DeepseekV2DecoderLayer
(
config
,
prefix
,
model_config
,
cache_config
,
quant_config
)
cache_config
,
quant_config
)
def
fuse_fill_rms_x2_concat
(
hidden_states_fuse
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
weight_inputs_embeds
:
torch
.
Tensor
,
weight_previous_hidden_states
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
from
lightop
import
fuse_fill_rms_x2_concat
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
weight_inputs_embeds
,
weight_previous_hidden_states
,
epsilon
)
def
fuse_fill_rms_x2_concat_fake
(
hidden_states_fuse
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
weight_inputs_embeds
:
torch
.
Tensor
,
weight_previous_hidden_states
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
pass
direct_register_custom_op
(
op_name
=
"fuse_fill_rms_x2_concat"
,
op_func
=
fuse_fill_rms_x2_concat
,
mutates_args
=
[
"hidden_states_fuse"
,
"inputs_embeds"
],
fake_impl
=
fuse_fill_rms_x2_concat_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -84,10 +104,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
...
@@ -84,10 +104,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
# masking inputs at position 0, as not needed by MTP
if
envs
.
VLLM_USE_FUSED_FILL_RMS_CAT
:
hidden_states_fuse
=
torch
.
empty
(
inputs_embeds
.
shape
[
0
],
inputs_embeds
.
shape
[
1
]
*
2
,
device
=
inputs_embeds
.
device
,
dtype
=
inputs_embeds
.
dtype
)
torch
.
ops
.
vllm
.
fuse_fill_rms_x2_concat
(
hidden_states_fuse
,
positions
,
inputs_embeds
,
previous_hidden_states
,
self
.
enorm
.
weight
,
self
.
hnorm
.
weight
,
self
.
enorm
.
variance_epsilon
)
hidden_states
=
self
.
eh_proj
(
hidden_states_fuse
)
else
:
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
hidden_states
=
self
.
eh_proj
(
hidden_states
=
self
.
eh_proj
(
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
torch
.
cat
([
inputs_embeds
,
previous_hidden_states
],
dim
=-
1
))
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
82cd3c88
...
@@ -22,22 +22,19 @@
...
@@ -22,22 +22,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import
typing
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
os
import
os
import
re
import
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
...
@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
parallel_config
=
vllm_config
.
parallel_config
quant_config
=
vllm_config
.
quant_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
ep_group
=
get_ep_group
().
device_group
self
.
ep_rank
=
self
.
ep_group
.
rank
()
self
.
ep_size
=
self
.
ep_group
.
size
()
self
.
n_routed_experts
=
config
.
num_experts
self
.
is_sequence_parallel
=
parallel_config
.
use_sequence_parallel_moe
if
self
.
tp_size
>
config
.
num_experts
:
if
self
.
tp_size
>
config
.
num_experts
:
raise
ValueError
(
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
num_experts
}
."
)
f
"the number of experts
{
config
.
num_experts
}
."
)
# Load balancing settings.
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_experts
,
vllm_config
=
get_current_vllm_config
()
eplb_config
=
vllm_config
.
parallel_config
.
eplb_config
self
.
enable_eplb
=
parallel_config
.
enable_eplb
self
.
n_logical_experts
=
self
.
n_routed_experts
self
.
n_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
n_physical_experts
=
(
self
.
n_logical_experts
+
self
.
n_redundant_experts
)
self
.
n_local_physical_experts
=
self
.
n_physical_experts
//
self
.
ep_size
self
.
physical_expert_start
=
(
self
.
ep_rank
*
self
.
n_local_physical_experts
)
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
Tru
e
,
reduce_results
=
Fals
e
,
renormalize
=
config
.
norm_topk_prob
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
)
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
config
.
num_experts
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
prefix
=
f
"
{
prefix
}
.gate"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
hidden_states
.
dim
(
# NOTE: hidden_states can have either 1D or 2D shape.
)
<=
2
,
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
orig_shape
=
hidden_states
.
shape
is_input_1d
=
hidden_states
.
dim
()
==
1
hidden_dim
=
hidden_states
.
shape
[
-
1
]
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
is_sequence_parallel
:
hidden_states
=
sequence_parallel_chunk
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
self
.
is_sequence_parallel
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
# noqa E501
final_hidden_states
,
0
)
final_hidden_states
)
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
# return to 1d if input is 1d
return
final_hidden_states
.
view
(
orig_shape
)
return
final_hidden_states
.
squeeze
(
0
)
if
is_input_1d
else
\
final_hidden_states
class
Qwen3MoeAttention
(
nn
.
Module
):
class
Qwen3MoeAttention
(
nn
.
Module
):
...
@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
dual_chunk_attention_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
dual_chunk_attention_config
=
dual_chunk_attention_config
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
)
self
.
attn
=
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
)
**
{
"layer_idx"
:
extract_layer_index
(
prefix
),
"dual_chunk_attention_config"
:
dual_chunk_attention_config
,
}
if
dual_chunk_attention_config
else
{},
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -278,7 +283,36 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -278,7 +283,36 @@ class Qwen3MoeAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# Add qk-norm
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
# # q, k 使用 continuous
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
# Add qk-norm then RoPE (original path).
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
if
envs
.
VLLM_USE_APEX_RN
:
...
@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
8192
)
dual_chunk_attention_config
=
getattr
(
config
,
"dual_chunk_attention_config"
,
None
)
self
.
self_attn
=
Qwen3MoeAttention
(
self
.
self_attn
=
Qwen3MoeAttention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
)
# `mlp_only_layers` in the config.
# `mlp_only_layers` in the config.
...
@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
config
.
num_experts
>
0
and
config
.
num_experts
>
0
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
vllm_config
=
vllm_config
,
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
...
@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module):
...
@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
parallel_config
=
vllm_config
.
parallel_config
eplb_config
=
parallel_config
.
eplb_config
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module):
...
@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
vllm_config
=
vllm_config
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
...
@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module):
...
@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module):
...
@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
num_redundant_experts
=
self
.
num_redundant_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module):
...
@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module):
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
current_count
=
loaded_weight
.
current_count
current_count
=
loaded_weight
.
current_count
...
@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module):
...
@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
name
.
endswith
(
"scale"
):
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
param
.
weight_loader
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
is_expert_weight
=
False
for
mapping
in
expert_params_mapping
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Anyway, this is an expert weight and should not be
# Skip layers on other devices.
# attempted to load as other weights later
if
is_pp_missing_parameter
(
name
,
self
):
is_expert_weight
=
True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name_mapped
,
self
):
continue
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name_mapped
.
endswith
(
if
name
.
endswith
(
ignore_suffixes
ignore_suffixes
)
and
name
not
in
params_dict
:
)
and
name_mapped
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name_mapped
]
weight_loader
=
param
.
weight_loader
# We should ask the weight loader to return success or not
weight_loader
(
param
,
# here since otherwise we may skip experts with other
# available replicas.
weight_loader
=
typing
.
cast
(
Callable
[...,
bool
],
param
.
weight_loader
)
success
=
weight_loader
(
param
,
loaded_weight
,
loaded_weight
,
name_mapped
,
name
,
shard_id
=
shard_id
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
expert_id
=
expert_id
)
return_success
=
True
)
if
success
:
name
=
name_mapped
break
break
else
:
else
:
if
is_expert_weight
:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name
.
endswith
(
if
name
.
endswith
(
ignore_suffixes
)
and
name
not
in
params_dict
:
ignore_suffixes
)
and
name
not
in
params_dict
:
...
@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module):
return
loaded_params
return
loaded_params
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
,
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
MixtureOfExperts
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_
text_
config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
)
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
continue
assert
isinstance
(
layer
,
Qwen3MoeDecoderLayer
)
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
example_layer
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
if
example_layer
is
None
:
raise
RuntimeError
(
"No Qwen3MoE layer found in the model.layers."
)
self
.
num_moe_layers
=
len
(
self
.
moe_layers
)
self
.
num_expert_groups
=
1
self
.
num_shared_experts
=
0
self
.
num_logical_experts
=
example_layer
.
n_logical_experts
self
.
num_physical_experts
=
example_layer
.
n_physical_experts
self
.
num_local_physical_experts
=
example_layer
.
n_local_physical_experts
self
.
num_routed_experts
=
example_layer
.
n_routed_experts
self
.
num_redundant_experts
=
example_layer
.
n_redundant_experts
def
set_eplb_state
(
self
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
)
->
None
:
for
layer_idx
,
layer
in
enumerate
(
self
.
moe_layers
):
# Register the expert weights.
self
.
expert_weights
.
append
(
layer
.
get_expert_weights
())
layer
.
set_eplb_state
(
moe_layer_idx
=
layer_idx
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
def
update_physical_experts_metadata
(
self
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
)
->
None
:
assert
self
.
num_local_physical_experts
==
num_local_physical_experts
self
.
num_physical_experts
=
num_physical_experts
self
.
num_local_physical_experts
=
num_local_physical_experts
self
.
num_redundant_experts
=
(
num_physical_experts
-
self
.
num_logical_experts
)
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
moe
=
layer
.
mlp
moe
.
n_local_physical_experts
=
num_local_physical_experts
moe
.
n_physical_experts
=
num_physical_experts
moe
.
n_redundant_experts
=
self
.
num_redundant_experts
moe
.
experts
.
update_expert_map
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/v1/attention/backends/mla/common.py
View file @
82cd3c88
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...
@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
kv_cache_dtype_str
=
"bf16"
else
:
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
kv_cache_dtype_str
=
self
.
kv_cache_dtype
from
lightop
import
fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous
(
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
positions
[:
num_actual_toks
,
...],
q
,
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment