Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
483acdc4
Commit
483acdc4
authored
Jun 13, 2025
by
xuxz
Browse files
增加fused_experts_impl_int8的接入
parent
942368c7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
78 deletions
+49
-78
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+49
-78
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
483acdc4
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
lmslim.layers.gemm.int8_utils
import
(
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_group_quant_int8
,
per_token_quant_int8
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
invoke_fused_moe_kernel_int8
,
get_w8a8moe_json
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
fused_experts_impl_int8
,
get_w8a8moe_json
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -1467,7 +1467,34 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1467,7 +1467,34 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
):
use_nn_moe
:
Optional
[
bool
]
=
False
):
# Check constraints.
# Check constraints.
if
use_int8_w8a8
is
True
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
True
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
)
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
2
],
"Hidden size mismatch"
2
],
"Hidden size mismatch"
...
@@ -1499,24 +1526,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1499,24 +1526,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
not
use_int8_w8a8
:
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
)
config
=
get_config_func
(
M
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
)
config
=
get_config_func
(
M
)
# We can reuse the memory between these because by the time we need
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
# cache3, we're done with cache1
...
@@ -1569,12 +1596,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1569,12 +1596,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
config1
,
config2
=
get_w8a8moe_json
(
m
)
config
=
config1
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
A
=
curr_hidden_states
,
B
=
w1
,
B
=
w1
,
...
@@ -1596,31 +1618,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1596,31 +1618,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
global_num_experts
,
expert_map
))
if
use_int8_w8a8
:
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
invoke_fused_moe_kernel_int8
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
qa1_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
apply_router_weight_on_input
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
else
:
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
qa1_scale
,
qa1_scale
,
...
@@ -1663,35 +1661,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1663,35 +1661,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
if
use_int8_w8a8
:
config
=
config2
invoke_fused_moe_kernel_int8
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
qa2_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
else
:
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
qa2_scale
,
qa2_scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment