Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ee85c63
Commit
4ee85c63
authored
Apr 17, 2026
by
lixh6
Committed by
lixh6
Apr 22, 2026
Browse files
[FEATURE] 接入Aiter MoE W8A8 量化模型支持
parent
aef3c487
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
180 additions
and
74 deletions
+180
-74
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+70
-30
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+106
-44
No files found.
vllm/envs.py
View file @
4ee85c63
...
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
...
@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_MOE_USE_DEEP_GEMM
:
bool
=
True
VLLM_MOE_USE_DEEP_GEMM
:
bool
=
True
VLLM_USE_DEEP_GEMM_E8M0
:
bool
=
True
VLLM_USE_DEEP_GEMM_E8M0
:
bool
=
True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES
:
bool
=
True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES
:
bool
=
True
VLLM_USE_AITER_MOE_W8A8
:
bool
=
True
VLLM_DEEP_GEMM_WARMUP
:
Literal
[
VLLM_DEEP_GEMM_WARMUP
:
Literal
[
"skip"
,
"skip"
,
"full"
,
"full"
,
...
@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
:
lambda
:
bool
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
,
"1"
))
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES"
,
"1"
))
),
),
"VLLM_USE_AITER_MOE_W8A8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_AITER_MOE_W8A8"
,
"1"
))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
# JIT'ing in the hot-path. However, this warmup increases the engine
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
4ee85c63
...
@@ -6,7 +6,9 @@ import functools
...
@@ -6,7 +6,9 @@ import functools
import
json
import
json
import
os
import
os
import
math
import
math
import
sys
import
aiter
from
aiter.moe
import
get_aiter_moe_config
,
aiter_moe
,
MoeQuantType
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
...
@@ -1858,35 +1860,73 @@ def fused_experts_impl(
...
@@ -1858,35 +1860,73 @@ def fused_experts_impl(
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
or
use_fp8_w8a8
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
w1
=
w1
,
K_input
=
hidden_states
.
size
(
1
)
w2
=
w2
,
actual_N2
=
N
//
2
topk_weights
=
topk_weights
,
quant_type
=
MoeQuantType
.
W8A8
topk_ids
=
topk_ids
,
status
,
moe_config
=
get_aiter_moe_config
(
cache13
=
cache13
,
M
=
num_tokens
,
inplace
=
inplace
,
E
=
global_num_experts
,
activation
=
activation
,
N1
=
N
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
N2
=
actual_N2
,
use_fp8_w8a8
=
use_fp8_w8a8
,
K
=
K_input
,
use_int8_w8a8
=
use_int8_w8a8
,
top_k
=
top_k_num
,
use_int8_w8a16
=
False
,
block_size
=
0
,
use_int4_w4a16
=
False
,
dtype
=
hidden_states
.
dtype
,
per_channel_quant
=
per_channel_quant
,
quant_type
=
quant_type
,
global_num_experts
=
global_num_experts
,
)
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
output
=
aiter_moe
(
w2_scale
=
w2_scale
,
hidden_states
=
hidden_states
,
w1_zp
=
w1_zp
,
w1
=
w1
,
w2_zp
=
w2_zp
,
w2
=
w2
,
a1_scale
=
a1_scale
,
topk_weights
=
topk_weights
,
a2_scale
=
a2_scale
,
topk_ids
=
topk_ids
,
block_shape
=
block_shape
,
moe_config
=
moe_config
,
use_nn_moe
=
False
,
inplace
=
inplace
,
routed_scaling_factor
=
routed_scaling_factor
,
activation
=
activation
,
shared_output
=
shared_output
,
w1_scale
=
w1_scale
,
i_q
=
i_q
,
w2_scale
=
w2_scale
,
i_s
=
i_s
w1_zp
=
w1_zp
,
)
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
None
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
routed_scaling_factor
=
routed_scaling_factor
,
)
return
output
else
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
cache13
=
cache13
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
i_q
=
i_q
,
i_s
=
i_s
)
elif
use_int4_w4a8
is
True
:
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w1
=
w1
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
4ee85c63
...
@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
,
FusedMoeWeightScaleSupported
,
)
)
import
aiter
from
aiter.test_common
import
checkAllclose
,
perftest
from
aiter.ops.shuffle
import
moe_layout_shuffle_gemm1
,
moe_layout_shuffle_gemm2
from
aiter.fused_moe
import
fused_topk
,
torch_moe
from
aiter
import
dtypes
,
ActivationType
from
aiter.moe
import
get_aiter_moe_config
,
aiter_moe
,
MoeSolutionType
,
MoeQuantType
try
:
try
:
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
...
@@ -369,28 +375,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -369,28 +375,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w13_input_scale
=
None
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
shuffle_w8a8_gemm1
(
self
,
weight_data
):
w1_marlin_list
=
[]
w_i8
=
weight_data
.
to
(
torch
.
int8
)
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
return
moe_layout_shuffle_gemm1
(
w_i8
)
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
def
shuffle_w8a8_gemm2
(
self
,
weight_data
):
w2_marlin_list
=
[]
w_i8
=
weight_data
.
to
(
torch
.
int8
)
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
return
moe_layout_shuffle_gemm2
(
w_i8
)
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
shuffled_w13
=
self
.
shuffle_w8a8_gemm1
(
layer
.
w13_weight
)
layer
.
w13_weight
=
Parameter
(
shuffled_w13
.
view
(
*
layer
.
w13_weight
.
shape
),
requires_grad
=
False
)
shuffled_w2
=
self
.
shuffle_w8a8_gemm2
(
layer
.
w2_weight
)
layer
.
w2_weight
=
Parameter
(
shuffled_w2
.
view
(
*
layer
.
w2_weight
.
shape
),
requires_grad
=
False
)
else
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -406,30 +428,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -406,30 +428,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts_impl_int8_marlin
(
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
hidden_states
=
x
,
m_flat
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
w1
=
layer
.
w13_weight
,
M
=
m_flat
.
shape
[
0
]
w2
=
layer
.
w2_weight
,
E
=
layer
.
w13_weight
.
size
(
0
)
topk_weights
=
topk_weights
,
K
=
x
.
size
(
-
1
)
topk_ids
=
topk_ids
,
N1
=
layer
.
w13_weight
.
size
(
1
)
inplace
=
True
,
topk
=
topk_ids
.
size
(
1
)
activation
=
layer
.
activation
,
w1_input
=
layer
.
w13_weight
.
view
(
E
,
N1
,
K
)
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
w2_input
=
layer
.
w2_weight
.
view
(
E
,
K
,
N1
//
2
)
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
_
,
moe_cfg
=
get_aiter_moe_config
(
global_num_experts
=
layer
.
global_num_experts
,
M
=
M
,
expert_map
=
layer
.
expert_map
,
E
=
E
,
quant_config
=
self
.
moe_quant_config
,
N1
=
N1
,
w1_scale
=
layer
.
w13_weight_scale
,
N2
=
N1
//
2
,
w2_scale
=
layer
.
w2_weight_scale
,
K
=
K
,
a1_scale
=
layer
.
w13_input_scale
,
top_k
=
topk
,
a2_scale
=
layer
.
w2_input_scale
,
block_size
=
0
,
use_nn_moe
=
False
,
dtype
=
x
.
dtype
,
i_q
=
i_q
,
quant_type
=
MoeQuantType
.
W8A8
,
i_s
=
i_s
,
)
shared_output
=
shared_output
,
output
=
aiter_moe
(
routed_scaling_factor
=
routed_scaling_factor
,
hidden_states
=
x
,
)
w1
=
w1_input
,
w2
=
w2_input
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
moe_config
=
moe_cfg
,
inplace
=
False
,
activation
=
getattr
(
layer
,
"activation"
,
"silu"
),
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
getattr
(
layer
,
"w13_input_scale"
,
None
),
a2_scale
=
getattr
(
layer
,
"w2_input_scale"
,
None
),
global_num_experts
=
E
,
expert_map
=
getattr
(
layer
,
"expert_map"
,
None
),
routed_scaling_factor
=
routed_scaling_factor
,
)
return
output
else
:
return
fused_experts_impl_int8_marlin
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
i_q
=
i_q
,
i_s
=
i_s
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment