Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7017f30c
Commit
7017f30c
authored
Jul 28, 2025
by
gaoqiong
Browse files
修改W4A8 以及W8A8量化量化092接口
parent
98958aed
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
81 additions
and
156 deletions
+81
-156
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+55
-4
vllm/model_executor/layers/quantization/blockwise_int8.py
vllm/model_executor/layers/quantization/blockwise_int8.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+5
-141
vllm/model_executor/layers/quantization/w8a8_int8.py
vllm/model_executor/layers/quantization/w8a8_int8.py
+10
-6
vllm/utils/__init__.py
vllm/utils/__init__.py
+10
-4
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7017f30c
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from
lmslim.layers.gemm.int8_utils
import
(
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_group_quant_int8
,
per_token_quant_int8
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
fused_experts_impl_int8
,
get_w8a8moe_json
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
fused_experts_impl_int8
,
get_w8a8moe_json
)
from
lmslim.layers.fused_moe.fuse_moe_w4a8
import
fused_experts_impl_w4a8
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
MoEPrepareAndFinalizeNoEP
)
...
@@ -653,6 +653,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -653,6 +653,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_channel_quant
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
...
@@ -1214,6 +1215,8 @@ def get_config_dtype_str(
...
@@ -1214,6 +1215,8 @@ def get_config_dtype_str(
return
"int8_w8a16"
return
"int8_w8a16"
elif
use_int4_w4a16
:
elif
use_int4_w4a16
:
return
"int4_w4a16"
return
"int4_w4a16"
elif
use_int4_w4a16
:
return
"int4_w4a8"
elif
dtype
==
torch
.
float
:
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
# use fp16/bfloat16 configs
...
@@ -1232,6 +1235,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1232,6 +1235,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1245,7 +1249,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1245,7 +1249,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
)
...
@@ -1263,6 +1267,7 @@ def inplace_fused_experts_fake(
...
@@ -1263,6 +1267,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1298,6 +1303,7 @@ def outplace_fused_experts(
...
@@ -1298,6 +1303,7 @@ def outplace_fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1312,7 +1318,7 @@ def outplace_fused_experts(
...
@@ -1312,7 +1318,7 @@ def outplace_fused_experts(
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
per_channel_quant
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
block_shape
,
use_nn_moe
)
...
@@ -1329,6 +1335,7 @@ def outplace_fused_experts_fake(
...
@@ -1329,6 +1335,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1383,6 +1390,7 @@ def fused_experts(
...
@@ -1383,6 +1390,7 @@ def fused_experts(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1442,6 +1450,7 @@ def fused_experts(
...
@@ -1442,6 +1450,7 @@ def fused_experts(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
@@ -1468,6 +1477,7 @@ def fused_experts_impl(
...
@@ -1468,6 +1477,7 @@ def fused_experts_impl(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1506,6 +1516,34 @@ def fused_experts_impl(
...
@@ -1506,6 +1516,34 @@ def fused_experts_impl(
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
False
use_nn_moe
=
False
)
)
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
True
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
)
#
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
size
(
1
)
//
2
==
w1
.
size
(
2
),
(
assert
hidden_states
.
size
(
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
"Hidden size mismatch"
)
...
@@ -1542,12 +1580,14 @@ def fused_experts_impl(
...
@@ -1542,12 +1580,14 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
qtype
=
get_config_quant_dtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
qtype
=
get_config_quant_dtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
)
get_config_func
=
functools
.
partial
(
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
try_get_optimal_moe_config
,
...
@@ -1648,6 +1688,7 @@ def fused_experts_impl(
...
@@ -1648,6 +1688,7 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1687,6 +1728,7 @@ def fused_experts_impl(
...
@@ -1687,6 +1728,7 @@ def fused_experts_impl(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
)
...
@@ -1714,6 +1756,7 @@ def fused_moe(
...
@@ -1714,6 +1756,7 @@ def fused_moe(
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1799,6 +1842,7 @@ def fused_moe(
...
@@ -1799,6 +1842,7 @@ def fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
@@ -1820,6 +1864,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1820,6 +1864,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
):
...
@@ -1829,6 +1874,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1829,6 +1874,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
))
))
...
@@ -1837,6 +1883,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1837,6 +1883,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
use_int4_w4a8
=
use_int4_w4a8
@
property
@
property
def
activation_formats
(
def
activation_formats
(
...
@@ -1966,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1966,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
block_shape
=
self
.
block_shape
)
...
@@ -1996,6 +2044,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -1996,6 +2044,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
block_shape
=
self
.
block_shape
)
...
@@ -2005,6 +2054,7 @@ def modular_triton_fused_moe(
...
@@ -2005,6 +2054,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
)
->
mk
.
FusedMoEModularKernel
:
...
@@ -2015,6 +2065,7 @@ def modular_triton_fused_moe(
...
@@ -2015,6 +2065,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
),
),
...
...
vllm/model_executor/layers/quantization/blockwise_int8.py
View file @
7017f30c
...
@@ -477,7 +477,7 @@ class BlockInt8MoEMethod:
...
@@ -477,7 +477,7 @@ class BlockInt8MoEMethod:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"EPLB not supported for `Moe
WNA16
Method` yet."
)
"EPLB not supported for `Moe
BlockInt8
Method` yet."
)
# Expert selection
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
7017f30c
...
@@ -974,147 +974,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -974,147 +974,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
)
class
CompressedTensorsW8A8Int8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
per_channel
:
raise
ValueError
(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
:
raise
ValueError
(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
int8
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
assert
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
assert
not
self
.
static_input_scales
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
def
__init__
(
...
@@ -1729,12 +1588,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1729,12 +1588,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsW8A8Int8Method` yet."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/w8a8_int8.py
View file @
7017f30c
...
@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod:
...
@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod:
# WEIGHTS
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
torch
.
int8
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
...
@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod:
...
@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod:
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
torch
.
int8
),
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
...
@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod:
...
@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod:
E
=
layer
.
w13_weight
.
shape
[
0
]
E
=
layer
.
w13_weight
.
shape
[
0
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N1
=
layer
.
w13_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
N2
=
layer
.
w2_weight
.
shape
[
1
]
K
=
layer
.
w2_weight
.
shape
[
2
]
K
=
N1
//
2
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
if
[
E
,
N1
,
N2
,
K
]
not
in
self
.
tritonsingleton
.
moe_weight_shapes
:
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
self
.
tritonsingleton
.
moe_weight_shapes
.
append
([
E
,
N1
,
N2
,
K
])
TOPK
=
self
.
tritonsingleton
.
topk
TOPK
=
self
.
tritonsingleton
.
topk
json_file
=
self
.
tritonsingleton
.
get_moeint8json_name
(
E
,
N1
,
N2
,
K
,
TOPK
)
json_file
=
self
.
tritonsingleton
.
get_moeint8json_name
(
E
,
N1
,
N2
,
K
,
TOPK
,
use_int4_w4a8
=
True
)
configs_dict
=
self
.
tritonsingleton
.
get_moeint8_triton_cache
(
json_file
,
E
,
N1
,
N2
,
K
,
TOPK
)
configs_dict
=
self
.
tritonsingleton
.
get_moeint8_triton_cache
(
json_file
,
E
,
N1
,
N2
,
K
,
TOPK
)
#warmup
#warmup
...
@@ -345,12 +345,16 @@ class W8A8Int8MoEMethod:
...
@@ -345,12 +345,16 @@ class W8A8Int8MoEMethod:
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `W8A8Int8MoeMethod` yet."
)
# Expert selection
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -374,7 +378,7 @@ class W8A8Int8MoEMethod:
...
@@ -374,7 +378,7 @@ class W8A8Int8MoEMethod:
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_int
8
_w
8
a8
=
True
,
use_int
4
_w
4
a8
=
True
,
per_channel_quant
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
activation
=
activation
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
...
...
vllm/utils/__init__.py
View file @
7017f30c
...
@@ -2060,7 +2060,13 @@ class W8a8GetCacheJSON:
...
@@ -2060,7 +2060,13 @@ class W8a8GetCacheJSON:
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
Optional
[
list
]
=
None
):
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
):
if
use_int4_w4a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
if
block_size
is
not
None
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment