Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
855cb148
Commit
855cb148
authored
Jan 22, 2026
by
王敏
Browse files
merge dev分支代码
parents
9135afe4
fe2e2705
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
282 additions
and
95 deletions
+282
-95
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+2
-1
vllm/model_executor/models/glm4.py
vllm/model_executor/models/glm4.py
+2
-1
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+95
-15
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+24
-1
vllm/model_executor/models/telechat2.py
vllm/model_executor/models/telechat2.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
vllm/utils/__init__.py
vllm/utils/__init__.py
+14
-6
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+61
-17
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+41
-23
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+11
-10
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-18
No files found.
vllm/model_executor/models/falcon.py
View file @
855cb148
...
...
@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
import
vllm.envs
as
envs
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
word_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/glm4.py
View file @
855cb148
...
...
@@ -31,6 +31,7 @@ from typing import Optional, Union
import
torch
from
torch
import
nn
from
transformers
import
Glm4Config
import
vllm.envs
as
envs
class
MultiModalConfigProxy
:
...
...
@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/qwen3.py
View file @
855cb148
...
...
@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if
not
hasattr
(
torch
.
ops
.
vllm
,
"rms_rotary_embedding_fuse"
):
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# Add qk-norm
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
else
:
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q
=
q_by_head
.
view
(
q
.
shape
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k
=
k_by_head
.
view
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
# Add qk-norm then RoPE (original path).
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
else
:
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q
=
q_by_head
.
view
(
q
.
shape
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k
=
k_by_head
.
view
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
855cb148
...
...
@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
try
:
from
vllm.model_executor.layers.fused_moe.router_capture
import
(
maybe_record_router_logits
,
)
except
ImportError
:
def
maybe_record_router_logits
(
*
,
layer_name
:
str
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
)
->
None
:
return
None
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
_router_top_k
=
int
(
config
.
num_experts_per_tok
)
self
.
_router_capture_layer_name
=
prefix
if
self
.
tp_size
>
config
.
num_experts
:
raise
ValueError
(
...
...
@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
()):
capture_enabled
=
envs
.
VLLM_MOE_ROUTER_CAPTURE
if
capture_enabled
:
maybe_record_router_logits
(
layer_name
=
self
.
_router_capture_layer_name
,
router_logits
=
router_logits
,
top_k
=
self
.
_router_top_k
,
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
...
...
@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
vllm/model_executor/models/telechat2.py
View file @
855cb148
...
...
@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
import
vllm.envs
as
envs
class
TeleChat2Model
(
LlamaModel
):
...
...
@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
...
...
vllm/platforms/rocm.py
View file @
855cb148
...
...
@@ -17,7 +17,7 @@ from vllm.utils import cuda_device_count_stateless
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
vllm.utils
import
SUPPORT_TC
if
not
SUPPORT_TC
:
os
.
environ
[
'VLLM_USE_V1'
]
=
'0'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_PA'
]
=
'0'
...
...
vllm/utils/__init__.py
View file @
855cb148
...
...
@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self
.
moe_weight_shapes
=
[]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
self
.
cache_json_data
=
{}
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
self
.
device_name
=
device_name
self
.
topk
=
1
...
...
@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
):
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_int4_w4a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
elif
use_int8_w8a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
INT
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
FP
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
INT
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
FP
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
if
file_path
in
self
.
cache_json_data
:
# 直接返回缓存数据,避免重复读取
return
self
.
cache_json_data
[
file_path
]
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
...
...
@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_dict
[
configs_key
]
=
sub_value
self
.
cache_json_data
[
file_path
]
=
configs_dict
return
configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
...
...
vllm/v1/attention/backends/mla/common.py
View file @
855cb148
...
...
@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
from
lightop
import
fused_rms_norm_rope_contiguous
,
fuse_rmsnorm_rope_quant_qkv
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -1233,21 +1233,62 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
if
has_prefill
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant
=
torch
.
empty_like
(
q_tensor
,
dtype
=
torch
.
float8_e4m3fn
,
device
=
q
.
device
)
q_scale
=
torch
.
empty
(
q
.
shape
[
0
],
dtype
=
torch
.
float32
,
device
=
q
.
device
)
fuse_rmsnorm_rope_quant_qkv
(
positions
[:
num_actual_toks
,
...],
query_nope
,
q
,
q_quant
,
q_scale
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
decode_q
=
q_quant
[:
num_decode_tokens
]
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope
=
decode_q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope
=
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
)
# todo: bmm support
decode_ql_nope
=
torch
.
bmm
(
q_scale
,
decode_q_nope
,
self
.
W_UK_T
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
else
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
855cb148
...
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8_with_cat
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm.logger
import
init_logger
...
...
@@ -181,31 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
if
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
o
,
_
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
...
...
vllm/v1/core/sched/scheduler.py
View file @
855cb148
...
...
@@ -1051,14 +1051,15 @@ class Scheduler(SchedulerInterface):
def
schedule
(
self
)
->
SchedulerOutput
:
if
envs
.
VLLM_USE_PD_SPLIT
:
return
self
.
schedule_split_pd
()
else
:
if
self
.
connector
is
not
None
:
return
self
.
schedule_default
()
if
self
.
full_cuda_graph
and
self
.
use_mla
and
self
.
num_spec_tokens
>
0
:
return
self
.
schedule_split_pd
()
if
self
.
use_mla
:
if
self
.
full_cuda_graph
and
self
.
num_spec_tokens
>
0
:
return
self
.
schedule_split_pd
()
else
:
return
self
.
schedule_default
()
else
:
return
self
.
schedule_default
()
return
self
.
schedule_split_pd
()
else
:
return
self
.
schedule_default
()
def
_update_after_schedule
(
self
,
...
...
@@ -1101,13 +1102,14 @@ class Scheduler(SchedulerInterface):
req_id
=
req
.
request_id
req_ids
.
append
(
req_id
)
num_tokens
=
req
.
num_generated_token_ids
if
self
.
use_pp
:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids
=
req
.
all_token_ids
[
-
num_tokens
:]
token_ids
=
req
.
all_token_ids
[
-
num_tokens
:]
if
num_tokens
>
0
else
[]
new_token_ids
.
append
(
token_ids
)
new_block_ids
.
append
(
req_to_new_block_ids
[
req_id
])
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
...
...
@@ -1241,7 +1243,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
request
.
num_generated_token_ids
=
1
request
.
num_generated_token_ids
=
len
(
generated_token_ids
)
if
scheduled_spec_token_ids
:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
...
...
@@ -1253,7 +1255,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
request
.
num_computed_tokens
-=
num_tokens_rejected
request
.
num_generated_token_ids
=
len
(
generated_token_ids
)
spec_decoding_stats
=
self
.
make_spec_decoding_stats
(
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
855cb148
...
...
@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model
,
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
)
set_forward_context
,
set_profilling
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
...
@@ -514,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids
=
req_data
.
new_token_ids
[
i
]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens
=
(
num_computed_tokens
+
len
(
new_token_ids
)
-
req_state
.
num_tokens
)
num_new_tokens
=
len
(
new_token_ids
)
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:]
)
new_token_ids
)
if
len
(
spec_token_ids
)
>
0
:
req_state
.
spec_token_ids
=
spec_token_ids
...
...
@@ -541,8 +539,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add
.
append
(
req_id
)
continue
if
not
is_last_rank
:
req_state
=
self
.
requests
[
req_id
]
self
.
input_batch
.
add_request
(
req_state
)
req_index
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
else
:
req_ids_to_add
.
append
(
req_id
)
continue
# Update the persistent batch.
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
...
...
@@ -554,13 +557,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
not
is_last_rank
:
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
end_token_index
=
num_computed_tokens
+
1
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_token_index
:
end_token_index
]
=
new_token_ids
[
-
1
]
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
self
.
input_batch
.
num_tokens
[
req_index
]
=
end_token_index
if
len
(
new_token_ids
)
>
0
:
end_token_index
=
num_computed_tokens
+
1
self
.
input_batch
.
token_ids_cpu
[
req_index
,
start_token_index
:
end_token_index
]
=
new_token_ids
[
-
1
]
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
self
.
input_batch
.
num_tokens
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
if
spec_token_ids
:
...
...
@@ -1598,6 +1602,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if
req_state
.
output_token_ids
:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
...
...
@@ -1675,11 +1684,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata
,
attn_metadata
,
)
if
spec_token_ids
is
not
None
:
for
i
in
discard_sampled_tokens_req_indices
:
spec_token_ids
[
i
].
clear
()
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
...
...
@@ -2089,7 +2096,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
randomize_inputs
=
envs
.
VLLM_RANDOMIZE_DP_DUMMY_INPUTS
and
dp_size
>
1
...
...
@@ -3460,6 +3467,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if
req_state
.
output_token_ids
:
continue
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
...
...
@@ -3481,7 +3493,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states
[:
num_scheduled_tokens
],
scheduler_output
,
)
#-----------------------------------
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment