Unverified commit ae4e2806, authored by Roger Wang, committed by GitHub

[Bugfix] Fix FI kernel `chunk_gated_delta_rule` output shape for Qwen3.5 (#34219)


Signed-off-by: Roger Wang <hey@rogerw.io>
parent cbea11c9
@@ -135,7 +135,7 @@ def fi_chunk_gated_delta_rule(
     fi_state = initial_state.to(torch.float32)
     fi_g = g.to(torch.float32)
     fi_beta = beta.to(torch.float32)
-    return chunk_gated_delta_rule_fi(
+    output, final_state = chunk_gated_delta_rule_fi(
         q=q,
         k=k,
         v=v,
@@ -145,6 +145,8 @@ def fi_chunk_gated_delta_rule(
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
     )
+    # Unsqueeze back to 4D (1, L, H, D) to match fla output format
+    return output.unsqueeze(0), final_state
 
 
 @CustomOp.register("chunk_gated_delta_rule")
......
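
The shape change in this diff can be illustrated with a small, shape-only sketch. It assumes, based only on the added comment, that the FI kernel returns a 3D output `(L, H, D)` while callers expect the fla-style 4D layout `(1, L, H, D)`. Tuples stand in for tensor shapes so the sketch runs without torch, and every function name below is a hypothetical stand-in rather than a vLLM or FlashInfer API.

```python
# Shape-only sketch: tuples stand in for tensor shapes, so torch is not needed.
# All names here are hypothetical illustrations, not vLLM or FlashInfer APIs.


def unsqueeze_shape(shape, dim):
    """Return the shape that tensor.unsqueeze(dim) would produce (dim >= 0)."""
    return shape[:dim] + (1,) + shape[dim:]


def fake_fi_kernel_shape(L, H, D):
    # Assumption (inferred from the diff comment): the FI kernel output is 3D.
    return (L, H, D)


def wrapper_before_fix(L, H, D):
    # Old behaviour: the kernel result was returned unchanged.
    return fake_fi_kernel_shape(L, H, D)


def wrapper_after_fix(L, H, D):
    # New behaviour: add a leading axis of size 1, as output.unsqueeze(0) does.
    return unsqueeze_shape(fake_fi_kernel_shape(L, H, D), 0)


fla_style = (1, 8, 4, 16)  # (1, L, H, D), the layout the comment attributes to fla
print(wrapper_before_fix(8, 4, 16) == fla_style)  # False
print(wrapper_after_fix(8, 4, 16) == fla_style)   # True
```

Only the output gains the leading axis; `final_state` is passed through unchanged in the diff, which is why the patched wrapper unpacks the kernel result before returning it.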