Unverified Commit 5b9af2f1 authored by senlyu163's avatar senlyu163 Committed by GitHub
Browse files

fix: handle output shape when sequence length is padded by `pad_tensor` (#709)



* fix: handle output shape when sequence length is padded by `pad_tensor`

* Fix condition to check output shape based on batch size

---------
Co-authored-by: default avatarMuyang Li <lmxyy1999@foxmail.com>
parent 955b5523
...@@ -174,5 +174,7 @@ def fused_qkv_norm_rottary( ...@@ -174,5 +174,7 @@ def fused_qkv_norm_rottary(
norm_k=norm_k.weight if norm_k is not None else None, norm_k=norm_k.weight if norm_k is not None else None,
rotary_emb=rotary_emb, rotary_emb=rotary_emb,
) )
if seq_len * batch_size < output.shape[0]:
output = output[: seq_len * batch_size, :]
output = output.view(batch_size, seq_len, -1) output = output.view(batch_size, seq_len, -1)
return output return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment