Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c2bcb0ab
Commit
c2bcb0ab
authored
Jul 26, 2025
by
yangql
Browse files
增加moe awq-marlin的支持
parent
cb37537e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
326 additions
and
175 deletions
+326
-175
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-1
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+140
-123
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+1
-1
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+43
-15
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+135
-33
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
No files found.
vllm/_custom_ops.py
View file @
c2bcb0ab
...
@@ -16,6 +16,10 @@ try:
...
@@ -16,6 +16,10 @@ try:
from
lmslim
import
quant_tools
from
lmslim
import
quant_tools
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
import
marlin
except
Exception
:
print
(
"INFO: Please install marlin if you want to infer awq of marlin.
\n
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -1473,7 +1477,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
...
@@ -1473,7 +1477,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device
=
b_q_weight
.
device
,
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
dtype
=
b_q_weight
.
dtype
)
for
e
in
range
(
num_experts
):
for
e
in
range
(
num_experts
):
output
[
e
]
=
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
output
[
e
]
=
torch
.
ops
.
marlin
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
size_n
,
num_bits
)
size_n
,
num_bits
)
return
output
return
output
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
c2bcb0ab
...
@@ -5,7 +5,11 @@ import functools
...
@@ -5,7 +5,11 @@ import functools
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
try
:
import
marlin
except
Exception
:
print
(
"INFO: Please install marlin if you want to infer awq moe of marlin.
\n
"
)
import
vllm.envs
as
envs
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
moe_align_block_size
,
try_get_optimal_moe_config
)
moe_align_block_size
,
try_get_optimal_moe_config
)
...
@@ -14,28 +18,31 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -14,28 +18,31 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
def
fused_marlin_moe
(
w1
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w2
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
quant_type_id
:
int
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale1
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
# workspace: Optional[torch.Tensor] = None,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
"""
"""
...
@@ -65,16 +72,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
...
@@ -65,16 +72,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
"""
"""
quant_type
=
ScalarType
.
from_id
(
quant_type_id
)
#
quant_type = ScalarType.from_id(quant_type_id)
assert
quant_type
in
[
#
assert quant_type in [
scalar_types
.
uint4
,
scalar_types
.
uint8b128
,
scalar_types
.
uint4b8
,
#
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
scalar_types
.
float8_e4m3fn
,
scalar_types
.
float4_e2m1f
#
scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
]
#
]
bit4_scalar_types
=
[
#
bit4_scalar_types = [
scalar_types
.
uint4
,
scalar_types
.
uint4b8
,
scalar_types
.
float4_e2m1f
#
scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
]
#
]
num_bits
=
4
if
quant_type
in
bit4_scalar_types
else
8
#
num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
...
@@ -87,35 +94,48 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
...
@@ -87,35 +94,48 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
num_bits
in
[
4
,
8
]
# assert num_bits in [4, 8]
# 目前只支持 uint4的量化结果
M
,
K
=
hidden_states
.
shape
assert
num_bits
in
[
4
]
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
M
,
K
=
hidden_states
.
shape
# 32, 7168
topk
=
topk_ids
.
shape
[
1
]
E
=
w1
.
shape
[
0
]
# 256
N
=
w2
.
shape
[
1
]
*
16
# 256
get_config_func
=
functools
.
partial
(
topk
=
topk_ids
.
shape
[
1
]
# 8
try_get_optimal_moe_config
,
# # 计算 topk_weights 和 topk_ids
w1
.
shape
,
# topk_weights, topk_ids = fused_topk(hidden_states, score, topk, False)
w2
.
shape
,
topk_ids
.
shape
[
1
],
# 选择 block_size_m 的逻辑按照 Marlin来设置
None
,
for
block_size_m
in
[
16
,
32
,
48
,
64
,
80
]:
is_marlin
=
True
,
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
)
break
config
=
get_config_func
(
M
)
# print("m: ", M, "; block_m: ", block_size_m)
block_size_m
=
config
[
"BLOCK_SIZE_M"
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
E
global_num_experts
=
E
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
moe_align_block_size
(
topk_ids
,
block_size_m
,
global_num_experts
,
moe_align_block_size
(
topk_ids
,
block_size_m
,
global_num_experts
,
expert_map
)
expert_map
)
# max_num = num_tokens_post_padded.item()
if
workspace
is
None
:
# print("max_num: ", max_num)
workspace
=
marlin_make_workspace_new
(
hidden_states
.
device
,
4
)
# 输出
# for i in range(0, max_num, block_size_m):
intermediate_cache2
=
torch
.
empty
(
# print(i / block_size_m, sorted_token_ids[i:(i + block_size_m)])
# if workspace is None:
# max_workspace_size = (max(2 * N, K) // 64) * \
# (sorted_token_ids.size(0) // block_size_m)
# device = hidden_states.device
# sms = torch.cuda.get_device_properties(device).multi_processor_count
# max_workspace_size = min(max_workspace_size, sms * 4)
# workspace = torch.zeros(max_workspace_size,
# dtype=torch.int,
# device=device,
# requires_grad=False)
scalar_type1
=
get_scalar_type
(
num_bits
,
w1_zeros
is
not
None
)
scalar_type2
=
get_scalar_type
(
num_bits
,
w2_zeros
is
not
None
)
intermediate_cache2
=
torch
.
empty
(
# [32*8, 256]
(
M
*
topk_ids
.
shape
[
1
],
N
),
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
...
@@ -125,94 +145,90 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
...
@@ -125,94 +145,90 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
# [32*8, 512]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
# # [32*8, 7168]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
maybe_warn_marlin_atomic_add
(
hidden_states
.
device
,
hidden_states
.
dtype
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
intermediate_cache1
=
ops
.
moe_wna16_marlin_gemm
(
intermediate_cache1
.
zero_
()
hidden_states
,
intermediate_cache1
=
torch
.
ops
.
marlin
.
moe_wna16_marlin_gemm
(
intermediate_cache1
,
hidden_states
,
# [32, 7168] # arg0: torch.Tensor,
w1
,
intermediate_cache1
,
# [32*8, 512] # arg1: Optional[torch.Tensor]
w1_scale
,
w1
,
# arg2: torch.Tensor
global_scale1
,
w1_scale
,
# arg3: torch.Tensor
w1_zeros
,
# w1_zeros, # arg4: Optional[torch.Tensor]
g_idx1
,
g_idx1
,
# arg5: Optional[torch.Tensor]
sort_indices1
,
sort_indices1
,
# arg6: Optional[torch.Tensor]
workspace
,
# workspace, # arg7: torch.Tensor
sorted_token_ids
,
sorted_token_ids
,
# arg8: torch.Tensor
expert_ids
,
expert_ids
,
# arg9: torch.Tensor
num_tokens_post_padded
,
num_tokens_post_padded
,
# arg10: torch.Tensor
topk_weights
,
topk_weights
,
#arg11: torch.Tensor,
moe_block_size
=
block_size_m
,
block_size_m
,
# arg12: int,
top_k
=
topk
,
topk
,
# arg13: int,
mul_topk_weights
=
apply_router_weight_on_input
,
False
,
# arg14: bool,
is_ep
=
expert_map
is
not
None
,
expert_map
is
not
None
,
# arg15: bool,
b_q_type
=
quant_type
,
scalar_type1
.
id
,
# arg16: int
size_m
=
M
,
M
,
# arg17: int,
size_n
=
2
*
N
,
2
*
N
,
# arg18: int
size_k
=
K
,
K
,
# arg19: int,
is_k_full
=
is_k_full
,
is_k_full
,
# arg20: bool,
use_atomic_add
=
use_atomic_add
,
use_atomic_add
,
# arg21: bool,
use_fp32_reduce
=
True
,
True
,
# arg22: bool
is_zp_float
=
False
)
False
)
# arg23: bool
# [32*8, 512] --> [32*8, 256]
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
if
expert_map
is
not
None
:
intermediate_cache3
.
zero_
()
intermediate_cache3
.
zero_
()
intermediate_cache3
=
torch
.
ops
.
marlin
.
moe_wna16_marlin_gemm
(
intermediate_cache3
=
ops
.
moe_wna16_marlin_gemm
(
intermediate_cache2
,
# [32*8, 256]
intermediate_cache2
,
intermediate_cache3
,
# [32*8, 7168]
intermediate_cache3
,
w2
,
w2
,
w2_scale
,
w2_scale
,
global_scale2
,
# w2_zeros,
w2_zeros
,
g_idx2
,
g_idx2
,
sort_indices2
,
sort_indices2
,
workspace
,
#
workspace,
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
topk_weights
,
topk_weights
,
moe_block_size
=
block_size_m
,
block_size_m
,
top_k
=
1
,
1
,
mul_topk_weights
=
not
apply_router_weight_on_input
,
True
,
is_ep
=
expert_map
is
not
None
,
expert_map
is
not
None
,
b_q_type
=
quant_type
,
scalar_type2
.
id
,
size_m
=
M
*
topk
,
M
*
topk
,
size_n
=
K
,
K
,
size_k
=
N
,
N
,
is_k_full
=
is_k_full
,
is_k_full
,
use_atomic_add
=
use_atomic_add
,
use_atomic_add
,
use_fp32_reduce
=
True
,
True
,
is_zp_float
=
False
).
view
(
-
1
,
topk
,
K
)
False
).
view
(
-
1
,
topk
,
K
)
output
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
output
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
# return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim
=
1
,
# dim=1,
out
=
output
)
# out=output)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
output
)
return
output
def
fused_marlin_moe_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
def
fused_marlin_moe_fake
(
w2
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
quant_type_id
:
int
,
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
global_scale1
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale2
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -220,7 +236,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
...
@@ -220,7 +236,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
# workspace: Optional[torch.Tensor] = None,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
c2bcb0ab
...
@@ -338,7 +338,7 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -338,7 +338,7 @@ class AWQLinearMethod(LinearMethodBase):
if
envs
.
VLLM_USE_TRITON_AWQ
:
if
envs
.
VLLM_USE_TRITON_AWQ
:
if
m
>
16
:
if
m
>
16
:
m
=
2
**
math
.
ceil
(
math
.
log2
(
m
)
)
m
=
1
<<
(
m
-
1
).
bit_length
(
)
best_config
=
getspec_config
(
m
,
n
,
k
)
best_config
=
getspec_config
(
m
,
n
,
k
)
out
=
awq_gemm_triton
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
,
best_config
)
out
=
awq_gemm_triton
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
,
best_config
)
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
1
]
*
8
,
))
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
1
]
*
8
,
))
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
c2bcb0ab
...
@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx
,
marlin_make_workspace_new
,
marlin_make_empty_g_idx
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supports_shape
,
awq_marlin_moe_permute_sz
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
...
@@ -131,10 +132,10 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -131,10 +132,10 @@ class AWQMarlinConfig(QuantizationConfig):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
# Check if the layer is supported by AWQMarlin.
# Check if the layer is supported by AWQMarlin.
if
not
check_marlin_supports_layer
(
layer
,
self
.
group_size
):
if
not
check_marlin_supports_layer
(
layer
,
self
.
group_size
):
logger
.
warning_once
(
#
logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels."
,
# noqa: E501
#
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix
,
#
prefix,
)
#
)
return
AWQConfig
.
from_config
(
return
AWQConfig
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
AWQMarlinLinearMethod
(
self
)
return
AWQMarlinLinearMethod
(
self
)
...
@@ -158,8 +159,8 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -158,8 +159,8 @@ class AWQMarlinConfig(QuantizationConfig):
group_size
=
quant_config
.
get
(
"group_size"
)
group_size
=
quant_config
.
get
(
"group_size"
)
zero_point
=
quant_config
.
get
(
"zero_point"
)
zero_point
=
quant_config
.
get
(
"zero_point"
)
if
not
current_platform
.
is_cuda
():
#
if not current_platform.is_cuda():
return
False
#
return False
if
quant_method
!=
"awq"
:
if
quant_method
!=
"awq"
:
return
False
return
False
...
@@ -441,7 +442,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -441,7 +442,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
)
)
replace_parameter
(
layer
,
"w13_scales"
,
marlin_w13_scales
)
#
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales
=
marlin_moe_permute_scales
(
marlin_w2_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w2_scales
,
s
=
layer
.
w2_scales
,
...
@@ -449,21 +450,41 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -449,21 +450,41 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_n
=
layer
.
w2_scales
.
shape
[
2
],
size_n
=
layer
.
w2_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
group_size
=
self
.
quant_config
.
group_size
,
)
)
replace_parameter
(
layer
,
"w2_scales"
,
marlin_w2_scales
)
#replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp
=
moe_awq_to_marlin_zero_points
(
marlin_w13_zp
=
moe_awq_to_marlin_zero_points
(
layer
.
w13_qzeros
,
layer
.
w13_qzeros
,
size_k
=
layer
.
w13_qzeros
.
shape
[
1
],
size_k
=
layer
.
w13_qzeros
.
shape
[
1
],
size_n
=
layer
.
w13_qzeros
.
shape
[
2
]
*
self
.
quant_config
.
pack_factor
,
size_n
=
layer
.
w13_qzeros
.
shape
[
2
]
*
self
.
quant_config
.
pack_factor
,
num_bits
=
self
.
quant_config
.
weight_bits
)
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_parameter
(
layer
,
"w13_qzeros"
,
marlin_w13_zp
)
#
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp
=
moe_awq_to_marlin_zero_points
(
marlin_w2_zp
=
moe_awq_to_marlin_zero_points
(
layer
.
w2_qzeros
,
layer
.
w2_qzeros
,
size_k
=
layer
.
w2_qzeros
.
shape
[
1
],
size_k
=
layer
.
w2_qzeros
.
shape
[
1
],
size_n
=
layer
.
w2_qzeros
.
shape
[
2
]
*
self
.
quant_config
.
pack_factor
,
size_n
=
layer
.
w2_qzeros
.
shape
[
2
]
*
self
.
quant_config
.
pack_factor
,
num_bits
=
self
.
quant_config
.
weight_bits
)
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_parameter
(
layer
,
"w2_qzeros"
,
marlin_w2_zp
)
# replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
marlin_w13_sz
=
awq_marlin_moe_permute_sz
(
marlin_w13_scales
,
marlin_w13_zp
,
size_k
=
layer
.
w13_scales
.
shape
[
1
]
*
self
.
quant_config
.
group_size
,
size_n
=
layer
.
w13_scales
.
shape
[
2
]
)
marlin_w2_sz
=
awq_marlin_moe_permute_sz
(
marlin_w2_scales
,
marlin_w2_zp
,
size_k
=
layer
.
w2_scales
.
shape
[
1
]
*
self
.
quant_config
.
group_size
,
size_n
=
layer
.
w2_scales
.
shape
[
2
]
)
replace_parameter
(
layer
,
"w13_scales"
,
marlin_w13_sz
)
replace_parameter
(
layer
,
"w2_scales"
,
marlin_w2_sz
)
layer
.
w13_qzeros
=
None
layer
.
w2_qzeros
=
None
torch
.
cuda
.
empty_cache
()
def
apply
(
def
apply
(
self
,
self
,
...
@@ -482,6 +503,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -482,6 +503,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -503,7 +527,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -503,7 +527,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
x
,
...
@@ -514,10 +540,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -514,10 +540,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
quant_type_id
=
self
.
quant_type
.
id
,
#
quant_type_id=self.quant_type.id,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
#
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_zeros
=
layer
.
w13_qzeros
,
w1_zeros
=
layer
.
w13_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
workspace
=
layer
.
workspace
)
# workspace=layer.workspace
num_bits
=
4
)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
c2bcb0ab
...
@@ -14,6 +14,10 @@ from vllm.platforms import current_platform
...
@@ -14,6 +14,10 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
from
.quant_utils
import
pack_cols
,
unpack_cols
try
:
import
marlin
except
Exception
:
print
(
"INFO: Please install marlin if you want to infer awq moe of marlin.
\n
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -153,7 +157,7 @@ def check_marlin_supports_shape(output_size_per_partition: int,
...
@@ -153,7 +157,7 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return
False
,
e
.
__str__
()
return
False
,
e
.
__str__
()
return
True
,
None
return
True
,
None
#暂不支持marlinlinear
def
check_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
\
def
check_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
\
->
bool
:
->
bool
:
output_size_per_partition
=
getattr
(
layer
,
"output_size_per_partition"
,
output_size_per_partition
=
getattr
(
layer
,
"output_size_per_partition"
,
...
@@ -161,12 +165,12 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
...
@@ -161,12 +165,12 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
input_size_per_partition
=
getattr
(
layer
,
"input_size_per_partition"
,
input_size_per_partition
=
getattr
(
layer
,
"input_size_per_partition"
,
None
)
or
layer
.
input_size
None
)
or
layer
.
input_size
return
check_marlin_supports_shape
(
#
return check_marlin_supports_shape(
output_size_per_partition
=
output_size_per_partition
,
#
output_size_per_partition=output_size_per_partition,
input_size_per_partition
=
input_size_per_partition
,
#
input_size_per_partition=input_size_per_partition,
input_size
=
layer
.
input_size
,
#
input_size=layer.input_size,
group_size
=
group_size
)[
0
]
#
group_size=group_size)[0]
return
False
def
check_moe_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
\
def
check_moe_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
\
->
bool
:
->
bool
:
...
@@ -237,30 +241,46 @@ def marlin_sort_g_idx(
...
@@ -237,30 +241,46 @@ def marlin_sort_g_idx(
return
g_idx
[
g_idx_sort_indices
],
g_idx_sort_indices
return
g_idx
[
g_idx_sort_indices
],
g_idx_sort_indices
def
get_scale_perms
():
# def get_scale_perms():
scale_perm
:
list
[
int
]
=
[]
# scale_perm: list[int] = []
for
i
in
range
(
8
):
# for i in range(8):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
# scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single
:
list
[
int
]
=
[]
# scale_perm_single: list[int] = []
for
i
in
range
(
4
):
# for i in range(4):
scale_perm_single
.
extend
(
# scale_perm_single.extend(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
# [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return
scale_perm
,
scale_perm_single
# return scale_perm, scale_perm_single
# def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
# group_size: int) -> torch.Tensor:
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
# scale_perm, scale_perm_single = get_scale_perms()
group_size
:
int
)
->
torch
.
Tensor
:
# if group_size < size_k and group_size != -1:
# s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
# else:
# s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
# s = s.reshape((-1, size_n)).contiguous()
scale_perm
,
scale_perm_single
=
get_scale_perms
()
# return s
if
group_size
<
size_k
and
group_size
!=
-
1
:
def
get_scale_perms
():
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
16
):
# 遍历列方向不同scale的 8个线程
scale_perm
.
extend
([
i
+
16
*
j
for
j
in
range
(
8
)])
# 插入 8 个数据块中 对应位置的索引
return
scale_perm
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
# [56, 512] # torch.float16
size_k
:
int
,
# 7168
size_n
:
int
,
# 512
group_size
:
int
# 128
)
->
torch
.
Tensor
:
# 将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
scale_perm
=
get_scale_perms
()
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
return
s
def
marlin_moe_permute_scales
(
def
marlin_moe_permute_scales
(
s
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
size_k
:
int
,
size_k
:
int
,
...
@@ -281,19 +301,18 @@ def marlin_moe_permute_scales(
...
@@ -281,19 +301,18 @@ def marlin_moe_permute_scales(
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
)
->
torch
.
Tensor
:
# Permute zero-points in a similar way to scales, but do not use the
# 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
# "single" permutation, since zero-points are applied on every MMA
scale_perm
=
get_scale_perms
()
scale_perm
,
_
=
get_scale_perms
()
zp
=
zp
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
zp
=
zp
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
#
Interleave column dim (for the dequantize code) and pack it to int32
#
uint4 混排
if
num_bits
==
4
:
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
# uint4打包成 int32
zp
=
zp
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
zp
=
zp
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
pack_cols
(
zp
,
num_bits
,
size_k
,
size_n
)
zp
=
pack_cols
(
zp
,
num_bits
,
size_k
,
size_n
)
...
@@ -474,3 +493,86 @@ def apply_awq_marlin_linear(
...
@@ -474,3 +493,86 @@ def apply_awq_marlin_linear(
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
return
output
.
reshape
(
out_shape
)
def
merge_scales_zeros
(
marlin_s
:
torch
.
Tensor
,
marlin_zp
:
torch
.
Tensor
,
data_num_0
:
int
,
data_num_1
:
int
)
->
torch
.
Tensor
:
"""
合并两个 Tensor, 每行交替取 data_num_0 个 float16 和 data_num_1 个 int32。
要求:
- marlin_s 每行长度能被 data_num_0 整除
- marlin_zp 每行长度能被 data_num_1 整除
- 合并后的总字节数必为 4 的倍数
返回:
[N, M] 的 int32 Tensor(行数一致,列数已对齐)
"""
assert
marlin_s
.
shape
[
0
]
==
marlin_zp
.
shape
[
0
],
"Batch size mismatch"
assert
marlin_s
.
dtype
==
torch
.
float16
assert
marlin_zp
.
dtype
==
torch
.
int32
N
,
D0
=
marlin_s
.
shape
_
,
D1
=
marlin_zp
.
shape
assert
D0
%
data_num_0
==
0
,
"marlin_s 每行必须能被 data_num_0 整除"
assert
D1
%
data_num_1
==
0
,
"marlin_zp 每行必须能被 data_num_1 整除"
s_block_count
=
D0
//
data_num_0
zp_block_count
=
D1
//
data_num_1
assert
s_block_count
==
zp_block_count
total_blocks
=
s_block_count
# 转为字节视图
s_bytes
=
marlin_s
.
view
(
torch
.
uint8
).
reshape
(
N
,
-
1
)
zp_bytes
=
marlin_zp
.
view
(
torch
.
uint8
).
reshape
(
N
,
-
1
)
# 每行的合并结果
merged_rows
=
[]
for
i
in
range
(
N
):
s_row
=
s_bytes
[
i
]
zp_row
=
zp_bytes
[
i
]
s_ptr
=
0
zp_ptr
=
0
merged
=
[]
for
_
in
range
(
total_blocks
):
# 如果 s 还有剩余 block,就取
if
s_ptr
<
s_row
.
numel
():
chunk_s
=
s_row
[
s_ptr
:
s_ptr
+
data_num_0
*
2
]
# float16 = 2 字节
merged
.
append
(
chunk_s
)
s_ptr
+=
data_num_0
*
2
# 如果 zp 还有剩余 block,就取
if
zp_ptr
<
zp_row
.
numel
():
chunk_zp
=
zp_row
[
zp_ptr
:
zp_ptr
+
data_num_1
*
4
]
# int32 = 4 字节
merged
.
append
(
chunk_zp
)
zp_ptr
+=
data_num_1
*
4
# 合并所有字节,并直接转换为 int32
merged_bytes
=
torch
.
cat
(
merged
)
# assert merged_bytes.numel() % 4 == 0, "最终字节长度必须是4的倍数"
merged_int32
=
merged_bytes
.
view
(
torch
.
int32
)
merged_rows
.
append
(
merged_int32
)
# 所有合并行长度一致,可以直接堆叠
result
=
torch
.
stack
(
merged_rows
)
return
result
def
awq_marlin_moe_permute_sz
(
s
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
)
->
torch
.
Tensor
:
num_experts
=
s
.
shape
[
0
]
# output = torch.empty((num_experts, size_k // 16, size_n//2 + size_n//8),
# device=z.device,
# dtype=z.dtype)
outputs
=
[]
for
e
in
range
(
num_experts
):
out_sz
=
merge_scales_zeros
(
s
[
e
],
z
[
e
],
128
,
16
)
outputs
.
append
(
out_sz
)
return
torch
.
stack
(
outputs
,
dim
=
0
)
vllm/model_executor/models/deepseek_mtp.py
View file @
c2bcb0ab
...
@@ -164,7 +164,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -164,7 +164,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
# The AWQ layer of MTP uses BlockInt8W8A8.
# The AWQ layer of MTP uses BlockInt8W8A8.
if
self
.
quant_method
==
"moe_wna16"
:
if
self
.
quant_method
==
"moe_wna16"
or
self
.
quant_method
==
"awq_marlin"
:
vllm_config
.
quant_config
=
BlockInt8Config
(
is_checkpoint_int8_serialized
=
True
,
weight_block_size
=
[
128
,
128
])
vllm_config
.
quant_config
=
BlockInt8Config
(
is_checkpoint_int8_serialized
=
True
,
weight_block_size
=
[
128
,
128
])
self
.
model
=
DeepSeekMultiTokenPredictor
(
vllm_config
=
vllm_config
,
self
.
model
=
DeepSeekMultiTokenPredictor
(
vllm_config
=
vllm_config
,
...
...
vllm/platforms/rocm.py
View file @
c2bcb0ab
...
@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
...
@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"w8a8_int8"
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"w8a8_int8"
,
"awq_marlin"
]
]
@
classmethod
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment