"testing/python/vscode:/vscode.git/clone" did not exist on "c264f37f33205231d5d8fd96a6e5b21c1bfef453"
Unverified commit 49f35393, authored by liu yuhao, committed by GitHub
Browse files

Fix the bug in issue #1266 (#1284)


Co-authored-by: cheeryBloosm <liu_yu_hao@126.com>
parent 551ac60d
...@@ -156,13 +156,14 @@ def main(): ...@@ -156,13 +156,14 @@ def main():
DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device='cuda')
for b in range(B): for b in range(B):
for t in range(SEQ_LEN): for t in range(SEQ_LEN):
for h in range(H): for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S] i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, :len(i_i)] = i_i block_indices[b, t, h, :len(i_i)] = i_i
block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item()
block_indices = block_indices.sort(-1)[0] block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device='cuda')
out = kernel(Q, K, V, block_indices.to(torch.int32)) out = kernel(Q, K, V, block_indices.to(torch.int32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment