Commit 41e6d686 authored by zhuwenwen
Browse files

Add int8 implementation of fused_experts_impl

parent a246d08c
@@ -1481,6 +1481,31 @@ def fused_experts_impl(
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=True,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False
)
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), ( assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch") "Hidden size mismatch")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment