Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d49bafc5
Commit
d49bafc5
authored
Apr 23, 2026
by
lixh6
Browse files
[FEATURE] 接入Aiter MoE W8A8-Int8 量化模型支持
parent
753b29c0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
100 additions
and
44 deletions
+100
-44
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+100
-44
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
d49bafc5
...
@@ -436,28 +436,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -436,28 +436,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w13_input_scale
=
None
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
shuffle_w8a8_gemm1
(
self
,
weight_data
):
w_i8
=
weight_data
.
to
(
torch
.
int8
)
return
moe_layout_shuffle_gemm1
(
w_i8
)
def
shuffle_w8a8_gemm2
(
self
,
weight_data
):
w_i8
=
weight_data
.
to
(
torch
.
int8
)
return
moe_layout_shuffle_gemm2
(
w_i8
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
if
not
self
.
use_deepep
:
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
shuffled_w13
=
self
.
shuffle_w8a8_gemm1
(
layer
.
w13_weight
)
else
:
layer
.
w13_weight
=
Parameter
(
shuffled_w13
.
view
(
*
layer
.
w13_weight
.
shape
),
requires_grad
=
False
)
w1_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w13_weight
[
ii
])
shuffled_w2
=
self
.
shuffle_w8a8_gemm2
(
layer
.
w2_weight
)
w1_marlin_list
.
append
(
w1_marlin_in
)
layer
.
w2_weight
=
Parameter
(
shuffled_w2
.
view
(
*
layer
.
w2_weight
.
shape
),
requires_grad
=
False
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
else
:
w1_marlin_list
=
[]
del
w1_marlin_list
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w2_marlin_list
=
[]
if
not
self
.
use_deepep
:
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
if
not
self
.
use_deepep
:
else
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_list
.
append
(
w1_marlin_in
)
w2_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w2_weight
[
ii
])
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
del
w1_marlin_list
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -473,30 +489,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -473,30 +489,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts_impl_int8_marlin
(
if
envs
.
VLLM_USE_AITER_MOE_W8A8
==
True
:
hidden_states
=
x
,
m_flat
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
w1
=
layer
.
w13_weight
,
M
=
m_flat
.
shape
[
0
]
w2
=
layer
.
w2_weight
,
E
=
layer
.
w13_weight
.
size
(
0
)
topk_weights
=
topk_weights
,
K
=
x
.
size
(
-
1
)
topk_ids
=
topk_ids
,
N1
=
layer
.
w13_weight
.
size
(
1
)
inplace
=
True
,
topk
=
topk_ids
.
size
(
1
)
activation
=
layer
.
activation
,
w1_input
=
layer
.
w13_weight
.
view
(
E
,
N1
,
K
)
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
w2_input
=
layer
.
w2_weight
.
view
(
E
,
K
,
N1
//
2
)
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
_
,
moe_cfg
=
get_aiter_moe_config
(
global_num_experts
=
layer
.
global_num_experts
,
M
=
M
,
expert_map
=
layer
.
expert_map
,
E
=
E
,
quant_config
=
self
.
moe_quant_config
,
N1
=
N1
,
w1_scale
=
layer
.
w13_weight_scale
,
N2
=
N1
//
2
,
w2_scale
=
layer
.
w2_weight_scale
,
K
=
K
,
a1_scale
=
layer
.
w13_input_scale
,
top_k
=
topk
,
a2_scale
=
layer
.
w2_input_scale
,
block_size
=
0
,
use_nn_moe
=
False
,
dtype
=
x
.
dtype
,
i_q
=
i_q
,
quant_type
=
MoeQuantType
.
W8A8
,
i_s
=
i_s
,
)
shared_output
=
shared_output
,
output
=
aiter_moe
(
routed_scaling_factor
=
routed_scaling_factor
,
hidden_states
=
x
,
)
w1
=
w1_input
,
w2
=
w2_input
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
moe_config
=
moe_cfg
,
inplace
=
False
,
activation
=
getattr
(
layer
,
"activation"
,
"silu"
),
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
getattr
(
layer
,
"w13_input_scale"
,
None
),
a2_scale
=
getattr
(
layer
,
"w2_input_scale"
,
None
),
global_num_experts
=
E
,
expert_map
=
getattr
(
layer
,
"expert_map"
,
None
),
routed_scaling_factor
=
routed_scaling_factor
,
)
return
output
else
:
return
fused_experts_impl_int8_marlin
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
i_q
=
i_q
,
i_s
=
i_s
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment