Commit 87709cf4 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Add a comment

parent c28eca99
...@@ -45,7 +45,7 @@ def generate_testcase(t: TestParam) -> Testcase: ...@@ -45,7 +45,7 @@ def generate_testcase(t: TestParam) -> Testcase:
for b in range(t.b): for b in range(t.b):
for s in range(t.s_q): for s in range(t.s_q):
for h in range(t.h_kv): for h in range(t.h_kv):
# TODO Comment # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk] cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment