"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "45ecb3ce4f716d5e2f9abc69f37f33607e02a5ef"
Commit a0dfa516 authored by Yuqing Xia's avatar Yuqing Xia Committed by LeiWang1999
Browse files

fix typo (#635)

parent 9c777b67
...@@ -104,8 +104,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -104,8 +104,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(acc_s, S_shared) T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast) T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = logsum[i] * sco logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
es_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment