Commit cd87548a authored by gaoqiong's avatar gaoqiong
Browse files

修复ds3 int8调用的功能

parent 8f73ab36
...@@ -373,18 +373,21 @@ class BlockInt8MoEMethod: ...@@ -373,18 +373,21 @@ class BlockInt8MoEMethod:
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
use_nn_moe: Optional[bool] = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
#print("===========fused_experts========================")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -409,6 +412,10 @@ class BlockInt8MoEMethod: ...@@ -409,6 +412,10 @@ class BlockInt8MoEMethod:
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_int8_w8a8=True, use_int8_w8a8=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale_inv), w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv), w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment