Commit 43287082 authored by Lucia Fang, committed by GitHub

[Bugfix] Fix missing per_act_token parameter in compressed_tensors_moe (#20509)


Signed-off-by: Lu Fang <fanglu@fb.com>
parent f73d02aa
@@ -322,7 +322,7 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    per_act_token: bool,
+    per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
@@ -366,6 +366,9 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
+    if per_act_token is None:
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
     per_out_ch = w1_scale.numel() != w1_q.size(0)
     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
......
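
The fallback added in the second hunk can be read as a small precedence rule: an explicit `per_act_token` wins, otherwise `a1_scale` decides, otherwise `a2_scale` decides, otherwise the result is `False`. The sketch below restates that rule outside of vLLM so the behavior is easy to check. The helper name `infer_per_act_token` and the `FakeScale` class are hypothetical, introduced only for illustration; `FakeScale` models nothing but the `numel()` method that the diff calls on the scale tensors.

```python
from typing import Optional


class FakeScale:
    """Hypothetical stand-in for a scale tensor; only numel() is modeled."""

    def __init__(self, n: int) -> None:
        self._n = n

    def numel(self) -> int:
        return self._n


def infer_per_act_token(
    per_act_token: Optional[bool],
    a1_scale: Optional[FakeScale] = None,
    a2_scale: Optional[FakeScale] = None,
) -> bool:
    """Mirror the nested conditional from the diff as explicit branches."""
    if per_act_token is not None:
        # The caller passed the flag explicitly, so nothing is inferred.
        return per_act_token
    if a1_scale is not None:
        # More than one scale element is treated as per-token scaling.
        return a1_scale.numel() != 1
    if a2_scale is not None:
        # a2_scale is consulted only when a1_scale is absent.
        return a2_scale.numel() != 1
    # No activation scales at all: default to per-tensor behavior.
    return False
```

Note that when both scales are present, only `a1_scale` is inspected, so a scalar `a1_scale` paired with a per-token `a2_scale` still yields `False`. Whether that combination occurs in practice is not stated in the commit.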