Commit 483acdc4 authored by xuxz's avatar xuxz
Browse files

增加fused_experts_impl_int8的接入

parent 942368c7
......@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lmslim.layers.fused_moe.fuse_moe_int8 import (invoke_fused_moe_kernel_int8,get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8,get_w8a8moe_json)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
......@@ -1468,6 +1468,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
# Check constraints.
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False,
use_int8_w8a8= True,
use_int8_w8a16= False,
use_int4_w4a16 = False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe= False
)
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
......@@ -1499,7 +1526,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if not use_int8_w8a8:
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
......@@ -1570,11 +1597,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
config1,config2=get_w8a8moe_json(m)
config=config1
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
......@@ -1596,30 +1618,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
if use_int8_w8a8:
invoke_fused_moe_kernel_int8(qcurr_hidden_states,
w1,
intermediate_cache1,
qa1_scale,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
else:
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
......@@ -1664,33 +1662,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
if use_int8_w8a8:
config=config2
invoke_fused_moe_kernel_int8(qintermediate_cache2,
w2,
intermediate_cache3,
qa2_scale,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
else:
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment