Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9820d063
Commit
9820d063
authored
Jul 30, 2025
by
yangql
Browse files
解决v1cudagraph的问题以及接入fuse moe marlin_V3版本
parent
c2bcb0ab
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
149 additions
and
129 deletions
+149
-129
vllm/_custom_ops.py
vllm/_custom_ops.py
+20
-5
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+115
-109
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+9
-6
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-2
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+1
-1
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+0
-4
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+2
-2
No files found.
vllm/_custom_ops.py
View file @
9820d063
...
...
@@ -10,16 +10,16 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
from
vllm.utils
import
direct_register_custom_op
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
import
marlin
import
lightop
except
Exception
:
print
(
"INFO: Please install
marlin
if you want to infer awq of marlin.
\n
"
)
print
(
"INFO: Please install
lightop
if you want to infer awq of marlin.
\n
"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -766,6 +766,14 @@ def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
splikspace
,
splikspacesize
)
def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
                  zeros_and_scales: torch.Tensor, m: int, n: int, k: int,
                  group_size: int, padding_group: int,
                  splikspace: torch.Tensor,
                  splikspacesize: int) -> torch.Tensor:
    """Shape-only placeholder with the same parameter list as ``awq_gemm``.

    No computation is performed: the function allocates an uninitialised
    tensor of shape ``(m, n)`` whose dtype and device are copied from
    ``input``. All remaining arguments are accepted but not read.

    Returns:
        An uninitialised ``(m, n)`` tensor on ``input.device`` with
        ``input.dtype``.
    """
    result_shape = (m, n)
    return torch.empty(result_shape,
                       dtype=input.dtype,
                       device=input.device)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
               group_size: int):
    """Forward all four arguments, unchanged, to ``quant_ops.convert_s4``.

    NOTE(review): presumably ``qw``/``qz``/``s`` are the quantized weights,
    zero points and scales -- confirm against the lmslim ``quant_ops`` API.

    Returns:
        Whatever ``quant_ops.convert_s4`` returns.
    """
    converted = quant_ops.convert_s4(qw, qz, s, group_size)
    return converted
...
...
@@ -1477,7 +1485,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
torch
.
ops
.
marlin
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
output
[
e
]
=
lightop
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
size_n
,
num_bits
)
return
output
...
...
@@ -2436,4 +2444,11 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
)
->
torch
.
Tensor
:
M
=
mat1
.
size
(
0
)
N
=
mat2
.
size
(
0
)
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
\ No newline at end of file
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
direct_register_custom_op
(
op_name
=
"awq_gemm"
,
op_func
=
awq_gemm
,
mutates_args
=
[],
fake_impl
=
awq_gemm_fake
,
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
9820d063
...
...
@@ -6,9 +6,10 @@ from typing import Optional
import
torch
try
:
import
marlin
import
lightop
except
Exception
:
print
(
"INFO: Please install marlin if you want to infer awq moe of marlin.
\n
"
)
print
(
"INFO: Please install lightop if you want to infer awq of marlin.
\n
"
)
import
vllm.envs
as
envs
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
...
...
@@ -28,12 +29,12 @@ def fused_marlin_moe(
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w1_scale
_zero
:
torch
.
Tensor
,
w2_scale
_zero
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -41,7 +42,7 @@ def fused_marlin_moe(
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
#
workspace: Optional[torch.Tensor] = None,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -94,48 +95,31 @@ def fused_marlin_moe(
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
# assert num_bits in [4, 8]
# 目前只支持 uint4的量化结果
# assert num_bits in [4]
assert
num_bits
in
[
4
]
M
,
K
=
hidden_states
.
shape
# 32, 7168
num_tokens
,
K
=
hidden_states
.
shape
# 32, 7168
E
=
w1
.
shape
[
0
]
# 256
N
=
w2
.
shape
[
1
]
*
16
# 256
topk
=
topk_ids
.
shape
[
1
]
# 8
# # 计算 topk_weights 和 topk_ids
# topk_weights, topk_ids = fused_topk(hidden_states, score, topk, False)
#暂时固定为16384
CHUNK_SIZE
=
16384
# 选择 block_size_m 的逻辑按照 Marlin来设置
for
block_size_m
in
[
16
,
32
,
48
,
64
,
80
]:
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
break
# print("m: ", M, "; block_m: ", block_size_m)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
moe_align_block_size
(
topk_ids
,
block_size_m
,
global_num_experts
,
expert_map
)
# max_num = num_tokens_post_padded.item()
# print("max_num: ", max_num)
# 输出
# for i in range(0, max_num, block_size_m):
# print(i / block_size_m, sorted_token_ids[i:(i + block_size_m)])
# if workspace is None:
# max_workspace_size = (max(2 * N, K) // 64) * \
# (sorted_token_ids.size(0) // block_size_m)
# device = hidden_states.device
# sms = torch.cuda.get_device_properties(device).multi_processor_count
# max_workspace_size = min(max_workspace_size, sms * 4)
# workspace = torch.zeros(max_workspace_size,
# dtype=torch.int,
# device=device,
# requires_grad=False)
if
workspace
is
None
:
sms
=
torch
.
cuda
.
get_device_properties
(
device
=
'cuda'
).
multi_processor_count
workspace
=
torch
.
zeros
(
sms
*
3
,
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
w1_zeros
is
not
None
)
scalar_type2
=
get_scalar_type
(
num_bits
,
w2_zeros
is
not
None
)
intermediate_cache2
=
torch
.
empty
(
# [32*8, 256]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
...
...
@@ -145,90 +129,112 @@ def fused_marlin_moe(
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
# [32*8, 512]
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
# # [32*8, 7168]
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
intermediate_cache1
.
zero_
()
intermediate_cache1
=
torch
.
ops
.
marlin
.
moe_wna16_marlin_gemm
(
hidden_states
,
# [32, 7168] # arg0: torch.Tensor,
intermediate_cache1
,
# [32*8, 512] # arg1: Optional[torch.Tensor]
w1
,
# arg2: torch.Tensor
w1_scale
,
# arg3: torch.Tensor
# w1_zeros, # arg4: Optional[torch.Tensor]
g_idx1
,
# arg5: Optional[torch.Tensor]
sort_indices1
,
# arg6: Optional[torch.Tensor]
# workspace, # arg7: torch.Tensor
sorted_token_ids
,
# arg8: torch.Tensor
expert_ids
,
# arg9: torch.Tensor
num_tokens_post_padded
,
# arg10: torch.Tensor
topk_weights
,
#arg11: torch.Tensor,
block_size_m
,
# arg12: int,
topk
,
# arg13: int,
False
,
# arg14: bool,
expert_map
is
not
None
,
# arg15: bool,
scalar_type1
.
id
,
# arg16: int
M
,
# arg17: int,
2
*
N
,
# arg18: int
K
,
# arg19: int,
is_k_full
,
# arg20: bool,
use_atomic_add
,
# arg21: bool,
True
,
# arg22: bool
False
)
# arg23: bool
# [32*8, 512] --> [32*8, 256]
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
.
zero_
()
intermediate_cache3
=
torch
.
ops
.
marlin
.
moe_wna16_marlin_gemm
(
intermediate_cache2
,
# [32*8, 256]
intermediate_cache3
,
# [32*8, 7168]
w2
,
w2_scale
,
# w2_zeros,
g_idx2
,
sort_indices2
,
# workspace,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
topk_weights
,
block_size_m
,
1
,
True
,
expert_map
is
not
None
,
scalar_type2
.
id
,
M
*
topk
,
K
,
N
,
is_k_full
,
use_atomic_add
,
True
,
False
).
view
(
-
1
,
topk
,
K
)
output
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
# return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
# dim=1,
# out=output)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
output
)
return
output
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
size
()
if
tokens_in_chunk
==
0
:
break
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
*
topk
,
:]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk
,
:]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
*
topk
,
:]
M
=
tokens_in_chunk
# Select block_size_m
for
block_size_m
in
[
16
,
32
,
48
,
64
,
80
]:
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
curr_topk_ids
,
block_size_m
,
global_num_experts
,
expert_map
)
intermediate_cache1
=
lightop
.
moe_marlin_w4a16
(
curr_hidden_states
,
intermediate_cache1
,
w1
,
w1_scale_zero
,
g_idx1
,
sort_indices1
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
curr_topk_weights
,
block_size_m
,
topk
,
False
,
expert_map
is
not
None
,
M
,
2
*
N
,
K
,
is_k_full
,
use_atomic_add
,
True
,
False
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
)
intermediate_cache3
=
lightop
.
moe_marlin_w4a16
(
intermediate_cache2
,
intermediate_cache3
,
w2
,
w2_scale_zero
,
g_idx2
,
sort_indices2
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
curr_topk_weights
,
block_size_m
,
1
,
True
,
expert_map
is
not
None
,
M
*
topk
,
K
,
N
,
is_k_full
,
use_atomic_add
,
True
,
False
).
view
(
-
1
,
topk
,
K
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
def
fused_marlin_moe_fake
(
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w1_scale
_zero
:
torch
.
Tensor
,
w2_scale
_zero
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -236,7 +242,7 @@ def fused_marlin_moe_fake(
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
#
workspace: Optional[torch.Tensor] = None,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
9820d063
...
...
@@ -67,8 +67,14 @@ def default_execution(k,n):
def
getspec_config
(
M
,
N
,
K
):
if
f
"
{
M
}
_
{
N
}
_
{
K
}
"
in
triton_configs_dict
:
return
triton_configs_dict
[
f
"
{
M
}
_
{
N
}
_
{
K
}
"
]
m_config
=
M
if
M
>
16
:
# 直接计算 2 的幂
m_config
=
1
while
m_config
<
M
:
m_config
*=
2
if
f
"
{
m_config
}
_
{
N
}
_
{
K
}
"
in
triton_configs_dict
:
return
triton_configs_dict
[
f
"
{
m_config
}
_
{
N
}
_
{
K
}
"
]
else
:
return
None
...
...
@@ -336,14 +342,11 @@ class AWQLinearMethod(LinearMethodBase):
padding_group
=
0
if
envs
.
VLLM_USE_TRITON_AWQ
:
if
m
>
16
:
m
=
1
<<
(
m
-
1
).
bit_length
()
best_config
=
getspec_config
(
m
,
n
,
k
)
out
=
awq_gemm_triton
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
,
best_config
)
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
1
]
*
8
,
))
else
:
out
=
ops
.
awq_gemm
(
reshaped_x
,
out
=
torch
.
ops
.
vllm
.
awq_gemm
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
9820d063
...
...
@@ -401,7 +401,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
device
=
layer
.
w13_qweight
.
device
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
3
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
num_experts
=
layer
.
w13_qweight
.
shape
[
0
]
...
...
@@ -546,6 +546,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map
=
expert_map
,
w1_zeros
=
layer
.
w13_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
#
workspace=layer.workspace
workspace
=
layer
.
workspace
,
num_bits
=
4
)
vllm/model_executor/layers/quantization/awq_triton.py
View file @
9820d063
...
...
@@ -294,7 +294,7 @@ def awq_gemm_triton(input: torch.Tensor,
scales
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
config
)
->
torch
.
Tensor
:
config
=
None
)
->
torch
.
Tensor
:
M
,
K
=
input
.
shape
N
=
qweight
.
shape
[
1
]
*
8
group_size
=
qweight
.
shape
[
0
]
//
qzeros
.
shape
[
0
]
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
9820d063
...
...
@@ -14,10 +14,6 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
try
:
import
marlin
except
Exception
:
print
(
"INFO: Please install marlin if you want to infer awq moe of marlin.
\n
"
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
9820d063
...
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.compilation.decorators
import
support_torch_compile
from
.deepseek_v2
import
(
DeepseekV2DecoderLayer
,
get_spec_layer_idx_from_weight_name
)
from
.interfaces
import
SupportsPP
...
...
@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata
)
return
logits
#@support_torch_compile
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment