Commit 87709cf4 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Add a comment

parent c28eca99
......@@ -45,7 +45,7 @@ def generate_testcase(t: TestParam) -> Testcase:
for b in range(t.b):
for s in range(t.s_q):
for h in range(t.h_kv):
# TODO Comment
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment