Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7b2122d9
Commit
7b2122d9
authored
Feb 08, 2026
by
jujl1
Browse files
feat: w4a8
parent
76ec56bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
171 deletions
+43
-171
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+0
-6
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+26
-114
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+17
-51
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7b2122d9
...
@@ -1760,12 +1760,6 @@ def fused_experts_impl(
...
@@ -1760,12 +1760,6 @@ def fused_experts_impl(
cache13
=
cache13
,
cache13
=
cache13
,
activation
=
activation
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
True
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
7b2122d9
...
@@ -20,7 +20,7 @@ from lmslim.layers.gemm.int8_utils import (
...
@@ -20,7 +20,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8
,
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_quant_int8
)
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
apply_int8_linear
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm
import
envs
...
@@ -94,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -94,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quantization_config
:
SlimQuantW4A8Int8Config
):
def
__init__
(
self
,
quantization_config
:
SlimQuantW4A8Int8Config
):
self
.
quantization_config
=
quantization_config
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
envs
.
VLLM_W8A8_BACKEND
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
n
=
layer
.
weight
.
shape
[
0
]
...
@@ -112,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -112,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for
key
,
value
in
configs_dict
.
items
():
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
else
:
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
weight_data
=
layer
.
weight
.
data
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
layer
.
weight
.
data
=
_weight
...
@@ -159,68 +161,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -159,68 +161,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
):
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
return
apply_int8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
bias
=
bias
,
w8a8_strategy
=
self
.
w8a8_strategy
,
input_quant_args
=
input_quant_args
,
silu_quant_args
=
silu_quant_args
)
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
//
8
)
*
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class
SlimQuantW4A8Int8MoEMethod
:
class
SlimQuantW4A8Int8MoEMethod
:
...
@@ -256,8 +204,7 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -256,8 +204,7 @@ class SlimQuantW4A8Int8MoEMethod:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
fused_experts
:
Optional
[
FusedMoEModularKernel
]
=
None
self
.
moe_mk
:
Optional
[
FusedMoEModularKernel
]
=
None
self
.
topk_indices_dtype
=
None
def
get_fused_moe_quant_config
(
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
...
@@ -270,9 +217,8 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -270,9 +217,8 @@ class SlimQuantW4A8Int8MoEMethod:
per_act_token_quant
=
True
,
per_act_token_quant
=
True
,
per_out_ch_quant
=
False
,
per_out_ch_quant
=
False
,
block_shape
=
None
,
block_shape
=
None
,
weight_dtype
=
'int4'
)
)
self
.
moe_quant_config
.
_w1
.
dtype
=
"int4"
self
.
moe_quant_config
.
_w1
.
dtype
=
"int4"
return
self
.
moe_quant_config
return
self
.
moe_quant_config
def
create_weights
(
def
create_weights
(
...
@@ -354,49 +300,15 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -354,49 +300,15 @@ class SlimQuantW4A8Int8MoEMethod:
)
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
top_k
:
int
,
topk_ids
:
torch
.
Tensor
,
renormalize
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
use_grouped_topk
:
bool
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
@@ -404,10 +316,10 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -404,10 +316,10 @@ class SlimQuantW4A8Int8MoEMethod:
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
activation
=
layer
.
activation
,
expert_map
=
expert_map
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
layer
.
global_num_experts
,
quant_config
=
self
.
moe_quant_config
,
quant_config
=
self
.
moe_quant_config
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
)
)
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
7b2122d9
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
FusedMoEModularKernel
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
except
Exception
:
except
Exception
:
...
@@ -147,8 +147,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -147,8 +147,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self
.
moe
=
moe
self
.
moe
=
moe
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
fused_experts
:
Optional
[
FusedMoEModularKernel
]
=
None
self
.
moe_mk
:
Optional
[
FusedMoEModularKernel
]
=
None
self
.
topk_indices_dtype
=
None
def
get_fused_moe_quant_config
(
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
:
self
,
layer
:
torch
.
nn
.
Module
)
:
...
@@ -218,46 +217,15 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -218,46 +217,15 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
top_k
:
int
,
topk_ids
:
torch
.
Tensor
,
renormalize
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
use_grouped_topk
:
bool
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
return
fused_experts_impl_w4a8_marlin
(
x
,
x
,
...
@@ -268,15 +236,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -268,15 +236,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
workspace
=
workspace
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
inplace
=
True
,
use_int4_w4a8
=
True
,
activation
=
layer
.
activation
,
per_channel_quant
=
True
,
expert_map
=
layer
.
expert_map
,
activation
=
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
expert_map
=
expert_map
,
global_num_experts
=
layer
.
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
moe_cache_getter
=
get_moe_cache
if
envs
.
VLLM_USE_GLOBAL_CACHE13
else
None
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
use_nn_moe
=
use_nn_moe
,
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment