Commit c32a7c7c (unverified), authored by shaochangxu, committed by GitHub

[Bugfix] fused_experts_impl wrong compute type for float32 (#11921)


Signed-off-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
Co-authored-by: shaochangxu.scx <shaochangxu.scx@antgroup.com>
parent 2118d056
@@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         device=hidden_states.device,
         dtype=hidden_states.dtype)
 
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+    if hidden_states.dtype == torch.bfloat16:
+        compute_type = tl.bfloat16
+    elif hidden_states.dtype == torch.float16:
+        compute_type = tl.float16
+    elif hidden_states.dtype == torch.float32:
+        compute_type = tl.float32
+    else:
+        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
 
     if inplace:
         out_hidden_states = hidden_states
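
The behavioural change in this hunk can be illustrated with a minimal, standalone sketch. It is not the patched function: plain strings stand in for the `torch.*` and `tl.*` dtype objects so that it runs without PyTorch or Triton, and the helper names `old_compute_type` and `new_compute_type` are hypothetical. It only shows that the removed conditional expression maps every non-bfloat16 input, including float32, to float16, while the added chain handles float32 explicitly and rejects anything else.

```python
# Stand-in sketch: strings replace torch.* / tl.* dtype objects (assumption),
# so the example runs without PyTorch or Triton installed.


def old_compute_type(dtype: str) -> str:
    """Removed logic: anything that is not bfloat16 becomes float16."""
    return "bfloat16" if dtype == "bfloat16" else "float16"


def new_compute_type(dtype: str) -> str:
    """Added logic: explicit mapping; unsupported dtypes raise ValueError."""
    if dtype == "bfloat16":
        return "bfloat16"
    elif dtype == "float16":
        return "float16"
    elif dtype == "float32":
        return "float32"
    else:
        raise ValueError(f"Unsupported compute_type: {dtype}")


if __name__ == "__main__":
    # Compare both versions on the three supported input dtypes.
    for name in ("bfloat16", "float16", "float32"):
        print(name, "old:", old_compute_type(name), "new:", new_compute_type(name))
```

In this sketch the two helpers agree on bfloat16 and float16 and differ only on float32, which is the case named in the commit title.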