Unverified commit 98f93db1, authored by 徐畅 and committed by GitHub
Browse files

[Bugfix] Remove redundant T.fill to fix precision issue (#667)

parent 722c2a8c
...@@ -169,7 +169,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -169,7 +169,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
T.fill(K_shared, 0)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K[bid, (seqlen_kv // num_split) * sid + K[bid, (seqlen_kv // num_split) * sid +
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment