Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
855cb148
Commit
855cb148
authored
Jan 22, 2026
by
王敏
Browse files
merge dev分支代码
parents
9135afe4
fe2e2705
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
282 additions
and
95 deletions
+282
-95
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+2
-1
vllm/model_executor/models/glm4.py
vllm/model_executor/models/glm4.py
+2
-1
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+95
-15
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+24
-1
vllm/model_executor/models/telechat2.py
vllm/model_executor/models/telechat2.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
vllm/utils/__init__.py
vllm/utils/__init__.py
+14
-6
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+61
-17
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+41
-23
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+11
-10
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-18
No files found.
vllm/model_executor/models/falcon.py
View file @
855cb148
...
@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
...
@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
import
vllm.envs
as
envs
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
...
@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
word_embeddings
(
input_ids
)
return
self
.
word_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/glm4.py
View file @
855cb148
...
@@ -31,6 +31,7 @@ from typing import Optional, Union
...
@@ -31,6 +31,7 @@ from typing import Optional, Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Glm4Config
from
transformers
import
Glm4Config
import
vllm.envs
as
envs
class
MultiModalConfigProxy
:
class
MultiModalConfigProxy
:
...
@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
vllm/model_executor/models/qwen3.py
View file @
855cb148
...
@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
...
@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
...
@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if
not
hasattr
(
torch
.
ops
.
vllm
,
"rms_rotary_embedding_fuse"
):
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -136,7 +189,34 @@ class Qwen3Attention(nn.Module):
...
@@ -136,7 +189,34 @@ class Qwen3Attention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# Add qk-norm
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
# Add qk-norm then RoPE (original path).
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
if
envs
.
VLLM_USE_APEX_RN
:
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
855cb148
...
@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...
@@ -38,6 +38,19 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
try
:
from
vllm.model_executor.layers.fused_moe.router_capture
import
(
maybe_record_router_logits
,
)
except
ImportError
:
def
maybe_record_router_logits
(
*
,
layer_name
:
str
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
)
->
None
:
return
None
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -111,6 +124,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
_router_top_k
=
int
(
config
.
num_experts_per_tok
)
self
.
_router_capture_layer_name
=
prefix
if
self
.
tp_size
>
config
.
num_experts
:
if
self
.
tp_size
>
config
.
num_experts
:
raise
ValueError
(
raise
ValueError
(
...
@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -140,6 +155,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
()):
capture_enabled
=
envs
.
VLLM_MOE_ROUTER_CAPTURE
if
capture_enabled
:
maybe_record_router_logits
(
layer_name
=
self
.
_router_capture_layer_name
,
router_logits
=
router_logits
,
top_k
=
self
.
_router_top_k
,
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
...
@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -453,7 +476,7 @@ class Qwen3MoeModel(nn.Module):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
...
vllm/model_executor/models/telechat2.py
View file @
855cb148
...
@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
...
@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
import
vllm.envs
as
envs
class
TeleChat2Model
(
LlamaModel
):
class
TeleChat2Model
(
LlamaModel
):
...
@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
...
@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
vllm/platforms/rocm.py
View file @
855cb148
vllm/utils/__init__.py
View file @
855cb148
...
@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
...
@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self
.
moe_weight_shapes
=
[]
self
.
moe_weight_shapes
=
[]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
self
.
cache_json_data
=
{}
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
self
.
device_name
=
device_name
self
.
device_name
=
device_name
self
.
topk
=
1
self
.
topk
=
1
...
@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
...
@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
):
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_int4_w4a8
:
if
use_int4_w4a8
:
if
block_size
is
not
None
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
elif
use_int8_w8a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
if
block_size
is
not
None
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
INT
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
FP
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
INT
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
FP
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
if
file_path
in
self
.
cache_json_data
:
# 直接返回缓存数据,避免重复读取
return
self
.
cache_json_data
[
file_path
]
cache_json_file
=
file_path
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
if
os
.
path
.
exists
(
file_path
):
...
@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
...
@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for
sub_key
,
sub_value
in
value
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_dict
[
configs_key
]
=
sub_value
configs_dict
[
configs_key
]
=
sub_value
self
.
cache_json_data
[
file_path
]
=
configs_dict
return
configs_dict
return
configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
# Adapted from: https://stackoverflow.com/a/47212782/5082708
...
...
vllm/v1/attention/backends/mla/common.py
View file @
855cb148
...
@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...
@@ -217,7 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
lightop
import
fused_rms_norm_rope_contiguous
from
lightop
import
fused_rms_norm_rope_contiguous
,
fuse_rmsnorm_rope_quant_qkv
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -1233,6 +1233,47 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1233,6 +1233,47 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
kv_cache_dtype_str
=
"bf16"
else
:
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
kv_cache_dtype_str
=
self
.
kv_cache_dtype
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
if
has_prefill
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant
=
torch
.
empty_like
(
q_tensor
,
dtype
=
torch
.
float8_e4m3fn
,
device
=
q
.
device
)
q_scale
=
torch
.
empty
(
q
.
shape
[
0
],
dtype
=
torch
.
float32
,
device
=
q
.
device
)
fuse_rmsnorm_rope_quant_qkv
(
positions
[:
num_actual_toks
,
...],
query_nope
,
q
,
q_quant
,
q_scale
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
positions
[:
num_actual_toks
,
...],
q
,
q
,
...
@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1259,12 +1300,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
decode_q
=
q_quant
[:
num_decode_tokens
]
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope
=
decode_q_nope
.
transpose
(
0
,
1
)
decode_q_nope
=
decode_q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope
=
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
)
# todo: bmm support
decode_ql_nope
=
torch
.
bmm
(
q_scale
,
decode_q_nope
,
self
.
W_UK_T
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
else
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
855cb148
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe
,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8_with_cat
,
get_mla_decoding_metadata_dense_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -181,6 +182,23 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -181,6 +182,23 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
o
,
_
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
...
...
vllm/v1/core/sched/scheduler.py
View file @
855cb148
...
@@ -1051,11 +1051,12 @@ class Scheduler(SchedulerInterface):
...
@@ -1051,11 +1051,12 @@ class Scheduler(SchedulerInterface):
def
schedule
(
self
)
->
SchedulerOutput
:
def
schedule
(
self
)
->
SchedulerOutput
:
if
envs
.
VLLM_USE_PD_SPLIT
:
if
envs
.
VLLM_USE_PD_SPLIT
:
if
self
.
use_mla
:
if
self
.
full_cuda_graph
and
self
.
num_spec_tokens
>
0
:
return
self
.
schedule_split_pd
()
return
self
.
schedule_split_pd
()
else
:
else
:
if
self
.
connector
is
not
None
:
return
self
.
schedule_default
()
return
self
.
schedule_default
()
if
self
.
full_cuda_graph
and
self
.
use_mla
and
self
.
num_spec_tokens
>
0
:
else
:
return
self
.
schedule_split_pd
()
return
self
.
schedule_split_pd
()
else
:
else
:
return
self
.
schedule_default
()
return
self
.
schedule_default
()
...
@@ -1101,13 +1102,14 @@ class Scheduler(SchedulerInterface):
...
@@ -1101,13 +1102,14 @@ class Scheduler(SchedulerInterface):
req_id
=
req
.
request_id
req_id
=
req
.
request_id
req_ids
.
append
(
req_id
)
req_ids
.
append
(
req_id
)
num_tokens
=
req
.
num_generated_token_ids
num_tokens
=
req
.
num_generated_token_ids
if
self
.
use_pp
:
if
self
.
use_pp
:
# When using PP, the scheduler sends the sampled tokens back,
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# need to send the sampled tokens back because the model runner
# will cache them.
# will cache them.
token_ids
=
req
.
all_token_ids
[
-
num_tokens
:]
token_ids
=
req
.
all_token_ids
[
-
num_tokens
:]
if
num_tokens
>
0
else
[]
new_token_ids
.
append
(
token_ids
)
new_token_ids
.
append
(
token_ids
)
new_block_ids
.
append
(
req_to_new_block_ids
[
req_id
])
new_block_ids
.
append
(
req_to_new_block_ids
[
req_id
])
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
...
@@ -1241,7 +1243,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1241,7 +1243,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids
=
(
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
request
.
num_generated_token_ids
=
1
request
.
num_generated_token_ids
=
len
(
generated_token_ids
)
if
scheduled_spec_token_ids
:
if
scheduled_spec_token_ids
:
# num_computed_tokens represents the number of tokens
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# processed in the current step, considering scheduled
...
@@ -1253,7 +1255,6 @@ class Scheduler(SchedulerInterface):
...
@@ -1253,7 +1255,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
len
(
generated_token_ids
))
request
.
num_computed_tokens
-=
num_tokens_rejected
request
.
num_computed_tokens
-=
num_tokens_rejected
request
.
num_generated_token_ids
=
len
(
generated_token_ids
)
spec_decoding_stats
=
self
.
make_spec_decoding_stats
(
spec_decoding_stats
=
self
.
make_spec_decoding_stats
(
spec_decoding_stats
,
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
855cb148
...
@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
...
@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model
,
prepare_communication_buffer_for_model
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
)
set_forward_context
,
set_profilling
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
@@ -514,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -514,14 +514,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids
=
req_data
.
new_token_ids
[
i
]
new_token_ids
=
req_data
.
new_token_ids
[
i
]
# Add the sampled token(s) from the previous step (if any).
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens
=
(
num_computed_tokens
+
len
(
new_token_ids
)
-
num_new_tokens
=
len
(
new_token_ids
)
req_state
.
num_tokens
)
if
num_new_tokens
==
1
:
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:]
)
new_token_ids
)
if
len
(
spec_token_ids
)
>
0
:
if
len
(
spec_token_ids
)
>
0
:
req_state
.
spec_token_ids
=
spec_token_ids
req_state
.
spec_token_ids
=
spec_token_ids
...
@@ -541,6 +539,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -541,6 +539,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch.
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
# scheduled in the previous step and needs to be added again.
if
not
is_last_rank
:
req_state
=
self
.
requests
[
req_id
]
self
.
input_batch
.
add_request
(
req_state
)
req_index
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
else
:
req_ids_to_add
.
append
(
req_id
)
req_ids_to_add
.
append
(
req_id
)
continue
continue
...
@@ -554,6 +557,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -554,6 +557,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
not
is_last_rank
:
if
not
is_last_rank
:
# Add new_token_ids to token_ids_cpu.
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
start_token_index
=
num_computed_tokens
if
len
(
new_token_ids
)
>
0
:
end_token_index
=
num_computed_tokens
+
1
end_token_index
=
num_computed_tokens
+
1
self
.
input_batch
.
token_ids_cpu
[
self
.
input_batch
.
token_ids_cpu
[
req_index
,
req_index
,
...
@@ -1598,6 +1602,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1598,6 +1602,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_len
=
(
req_state
.
num_computed_tokens
+
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
if
seq_len
<
req_state
.
num_tokens
:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if
req_state
.
output_token_ids
:
continue
# Ignore the sampled token for partial prefills.
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
# This relies on cuda-specific torch-internal impl details
...
@@ -1675,11 +1684,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1675,11 +1684,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata
,
spec_decode_metadata
,
attn_metadata
,
attn_metadata
,
)
)
if
spec_token_ids
is
not
None
:
if
spec_token_ids
is
not
None
:
for
i
in
discard_sampled_tokens_req_indices
:
for
i
in
discard_sampled_tokens_req_indices
:
spec_token_ids
[
i
].
clear
()
spec_token_ids
[
i
].
clear
()
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
...
@@ -3460,6 +3467,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3460,6 +3467,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
seq_len
=
(
req_state
.
num_computed_tokens
+
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
if
seq_len
<
req_state
.
num_tokens
:
# If we have already started decoding, seeing a "partial prefill"
# condition is suspicious and can lead to discarding the sampled
# token forever (PP stall).
if
req_state
.
output_token_ids
:
continue
# Ignore the sampled token for partial prefills.
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
# This relies on cuda-specific torch-internal impl details
...
@@ -3481,7 +3493,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3481,7 +3493,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states
[:
num_scheduled_tokens
],
hidden_states
[:
num_scheduled_tokens
],
scheduler_output
,
scheduler_output
,
)
)
#-----------------------------------
# Get the valid generated tokens.
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment