Commit 183a88cf authored by zhaochao's avatar zhaochao
Browse files

fix some notes


Signed-off-by: default avatarzhaochao <zhaochao1@sugon.com>
parent ca2958a8
...@@ -890,12 +890,12 @@ class FlashAttention(torch.nn.Module): ...@@ -890,12 +890,12 @@ class FlashAttention(torch.nn.Module):
elif q_format == "thd": elif q_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.reshape(output.shape[0], -1) output = output.reshape(output.shape[0], -1)
# Handle output shape when V head dim differs from Q/K head dim
if value_layer.shape[-1] != query_layer.shape[-1]: if value_layer.shape[-1] != query_layer.shape[-1]:
v_dim = value_layer.shape[-1] v_dim = value_layer.shape[-1]
num_heads = query_layer.shape[-2] num_heads = query_layer.shape[-2]
# 恢复为 (..., num_heads, head_dim_qk)
out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1]) out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1])
output = output.view(out_shape_heads)[..., :v_dim] # 裁剪到 V 的维度 output = output.view(out_shape_heads)[..., :v_dim]
output = output.reshape(output.shape[:-2] + (num_heads * v_dim,)) output = output.reshape(output.shape[:-2] + (num_heads * v_dim,))
return output.contiguous() return output.contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment