Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2767fc34
Commit
2767fc34
authored
Aug 01, 2025
by
gaoqiong
Browse files
增加w4a8相关支持修改
parent
98958aed
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
97 additions
and
148 deletions
+97
-148
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+56
-4
vllm/model_executor/layers/quantization/blockwise_int8.py
vllm/model_executor/layers/quantization/blockwise_int8.py
+4
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+20
-135
vllm/model_executor/layers/quantization/w8a8_int8.py
vllm/model_executor/layers/quantization/w8a8_int8.py
+9
-5
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+5
-1
No files found.
vllm/engine/arg_utils.py
View file @
2767fc34
...
...
@@ -1594,6 +1594,9 @@ class EngineArgs:
# For pooling tasks the default is False
if
model_config
.
runner_type
!=
"pooling"
:
self
.
enable_chunked_prefill
=
True
if
model_config
.
enable_chunked_prefill
is
not
None
and
\
model_config
.
enable_chunked_prefill
is
False
:
self
.
enable_chunked_prefill
=
False
if
self
.
enable_prefix_caching
is
None
:
self
.
enable_prefix_caching
=
True
else
:
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
2767fc34
...
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
fused_experts_impl_int8
,
get_w8a8moe_json
)
from
lmslim.layers.fused_moe.fuse_moe_w4a8
import
fused_experts_impl_w4a8
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
...
...
@@ -653,6 +653,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
...
...
@@ -1205,7 +1206,8 @@ def get_config_dtype_str(
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
)
->
Optional
[
str
]:
use_int8_w8a8
:
Optional
[
bool
]
=
False
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
)
->
Optional
[
str
]:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a8
:
...
...
@@ -1214,6 +1216,8 @@ def get_config_dtype_str(
return
"int8_w8a16"
elif
use_int4_w4a16
:
return
"int4_w4a16"
elif
use_int4_w4a8
:
return
"int4_w4a8"
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
...
...
@@ -1232,6 +1236,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1245,7 +1250,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1263,6 +1268,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1298,6 +1304,7 @@ def outplace_fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1312,7 +1319,7 @@ def outplace_fused_experts(
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1329,6 +1336,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1383,6 +1391,7 @@ def fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1442,6 +1451,7 @@ def fused_experts(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
@@ -1468,6 +1478,7 @@ def fused_experts_impl(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1506,6 +1517,34 @@ def fused_experts_impl(
block_shape
=
block_shape
,
use_nn_moe
=
False
)
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
True
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
)
#
if
use_int4_w4a16
:
assert
hidden_states
.
size
(
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
...
...
@@ -1542,6 +1581,7 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
dtype
=
hidden_states
.
dtype
)
qtype
=
get_config_quant_dtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
...
...
@@ -1648,6 +1688,7 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1687,6 +1728,7 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1714,6 +1756,7 @@ def fused_moe(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -1799,6 +1842,7 @@ def fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
@@ -1820,6 +1864,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
...
...
@@ -1829,6 +1874,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
))
...
...
@@ -1837,6 +1883,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
use_int4_w4a8
=
use_int4_w4a8
@
property
def
activation_formats
(
...
...
@@ -1914,6 +1961,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
dtype
=
hidden_states
.
dtype
)
config
=
try_get_optimal_moe_config
(
...
...
@@ -1966,6 +2014,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
...
...
@@ -1996,6 +2045,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
...
...
@@ -2005,6 +2055,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
...
...
@@ -2015,6 +2066,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
),
...
...
vllm/model_executor/layers/quantization/blockwise_int8.py
View file @
2767fc34
...
...
@@ -432,7 +432,7 @@ class BlockInt8MoEMethod:
E
=
layer
.
w13_weight
.
shape
[
0
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
K
=
layer
.
w2_weight
.
shape
[
2
]
K
=
N
//
2
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
...
...
@@ -445,7 +445,8 @@ class BlockInt8MoEMethod:
#warmup
if
configs_dict
:
self
.
tritonsingleton
.
triton_moejson_dict
.
update
(
configs_dict
)
#print("*************self.tritonsingleton:",self.tritonsingleton)
#生成模型配置文件
self
.
tritonsingleton
.
gen_model_json
(
block_size
)
...
...
@@ -477,7 +478,7 @@ class BlockInt8MoEMethod:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `Moe
WNA16
Method` yet."
)
"EPLB not supported for `Moe
BlockInt8
Method` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
2767fc34
...
...
@@ -1053,6 +1053,22 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
E
=
layer
.
w13_weight
.
shape
[
0
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
K
=
layer
.
w2_weight
.
shape
[
2
]
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
TOPK
=
self
.
tritonsingleton
.
topk
json_file
=
self
.
tritonsingleton
.
get_moeint8json_name
(
E
,
N1
,
N2
,
K
,
TOPK
)
configs_dict
=
self
.
tritonsingleton
.
get_moeint8_triton_cache
(
json_file
,
E
,
N1
,
N2
,
K
,
TOPK
)
#warmup
if
configs_dict
:
self
.
tritonsingleton
.
triton_moejson_dict
.
update
(
configs_dict
)
pass
def
apply
(
...
...
@@ -1076,6 +1092,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
...
...
@@ -1112,7 +1129,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
)
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
...
...
@@ -1636,137 +1654,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
w1_zp
=
None
,
w2_zp
=
None
,
block_shape
=
[
0
,
self
.
group_size
])
class
CompressedTensorsW8A8Int8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
not
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
):
raise
ValueError
(
"For INT8 Fused MoE layers, only per-channel scales"
"for activations and per-token scales for activations are supported. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
int8
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
({
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
static_input_scales
:
raise
ValueError
(
"For INT8 Fused MoE layers, only dynamic scales"
"for activations are supported. Found "
f
"
{
self
.
input_quant
}
"
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
\ No newline at end of file
vllm/model_executor/layers/quantization/w8a8_int8.py
100755 → 100644
View file @
2767fc34
...
...
@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod:
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
torch
.
int8
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
...
...
@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod:
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
torch
.
int8
),
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
...
...
@@ -306,7 +306,7 @@ class W8A8Int8MoEMethod:
E
=
layer
.
w13_weight
.
shape
[
0
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
K
=
layer
.
w2_weight
.
shape
[
2
]
K
=
N1
//
2
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
...
...
@@ -345,12 +345,16 @@ class W8A8Int8MoEMethod:
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `W8A8Int8MoeMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -374,7 +378,7 @@ class W8A8Int8MoEMethod:
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int
8
_w
8
a8
=
True
,
use_int
4
_w
4
a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
2767fc34
...
...
@@ -912,7 +912,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
)
->
torch
.
Tensor
:
assert
attn_metadata
.
prefill
is
not
None
has_context
=
attn_metadata
.
prefill
.
chunked_context
is
not
None
if
envs
.
VLLM_HAS_CONTEXT_DEFAULT
:
has_context
=
attn_metadata
.
prefill
.
chunked_context
is
not
None
else
:
has_context
=
False
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment