Commit d79204e5 authored by Lei Wang, committed by GitHub

[Release] Bump version to v0.1.1 (#107)

* Remove Torch CPP backend and update execution backend options

- Remove TorchCPPKernelAdapter and related code from JIT modules
- Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
- Remove "torch_cpp" from supported execution backend literals
- Simplify backend validation and remove unused torch_cpp-related code (see the sketch after this list)
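
A hedged sketch of the narrowed backend option as callers see it; the exact remaining literals ("dlpack", "ctypes", "cython") are an assumption based on the adapters shipped in this release, not a quote of the updated files:

    from typing import Literal, cast

    # Assumed shape of the execution-backend option after this change:
    # "torch_cpp" is no longer a supported literal.
    ExecutionBackend = Literal["dlpack", "ctypes", "cython"]

    def validate_backend(backend: str) -> ExecutionBackend:
        # Simplified stand-in for the backend validation this commit touches.
        if backend not in ("dlpack", "ctypes", "cython"):
            raise ValueError(f"Unsupported execution backend: {backend}")
        return cast(ExecutionBackend, backend)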

* lint fix

* Add block sparse attention implementations for TileLang and Triton

- Implement block sparse attention kernels for TileLang and Triton
- Add example scripts for block sparse attention with top-k and threshold-based masking
- Include utility functions for generating sparse attention masks (see the sketch after this list)
- Demonstrate causal attention with block-level sparsity
- Add test cases to validate sparse attention implementations against PyTorch reference
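
The masking scheme shared by both examples: scores pooled to block granularity are turned into a boolean block mask by keeping the top-k key blocks per query-block row, and the tests expand that mask back to token granularity with a Kronecker product before re-applying the causal triangle. A minimal CPU-only sketch of this scheme (toy sizes, mirroring the helpers in the diff below):

    import math
    import torch

    SEQ_LEN, BLOCK, TOPK = 256, 64, 2
    downsample_len = math.ceil(SEQ_LEN / BLOCK)  # number of key/query blocks

    # Block-level scores, one row per query block (stand-in for pooled attention scores).
    x_ds = torch.randn(1, 1, downsample_len, downsample_len)

    # Keep the TOPK highest-scoring key blocks per query-block row.
    sparse_index = torch.topk(x_ds, TOPK, dim=-1).indices
    block_mask = torch.zeros_like(x_ds, dtype=torch.bool)
    block_mask.scatter_(-1, sparse_index, True)
    block_mask.tril_()  # block-level causality

    # Expand to a token-level mask and re-apply the causal triangle,
    # exactly as the reference checks in the tests do.
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask &= torch.tril(torch.ones_like(full_mask))
    print(full_mask.shape)  # torch.Size([1, 1, 256, 256])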

* Bump version to 0.1.1

* Refactor block sparse attention examples for improved code quality

- Apply consistent code formatting and style in TileLang and Triton block sparse attention implementations
- Add ruff linter ignore comment for specific line in Triton implementation
- Improve readability by adjusting indentation and line breaks
- Standardize sparse mask generation and test function implementations
- Minor optimizations in test case configurations

* lint
parent c7462abf
MANIFEST.in
@@ -3,6 +3,7 @@ include CMakeLists.txt
include requirements.txt
include requirements-test.txt
include requirements-dev.txt
include tilelang/jit/adapter/cython/cython_wrapper.pyx
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *

VERSION
0.1.0 → 0.1.1
Block sparse attention example (TileLang)

@@ -7,14 +7,18 @@ import tilelang
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False,
                            dtype=torch.bool,
                            device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask

@@ -22,7 +26,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask

@@ -165,6 +169,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    return kernel_func(block_M, block_N, num_stages, threads)


def test_topk_sparse_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64

@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device='cuda',
                       dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

    # Run Triton kernel

@@ -194,8 +201,7 @@ def test_topk_sparse_attention():
    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

@@ -208,11 +214,11 @@ def test_topk_sparse_attention():
    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)

    # Verify accuracy
    assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
        "TileLang output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")


if __name__ == "__main__":
    test_topk_sparse_attention()
Block sparse attention example (Triton)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: E712
import math
import torch

@@ -7,6 +8,7 @@ import triton
import triton.language as tl
import torch.nn.functional as F


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"

@@ -15,10 +17,13 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False,
                            dtype=torch.bool,
                            device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask

@@ -26,22 +31,26 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


@triton.jit
def _fwd_kernel_inner(
        acc,
        l_i,
        m_i,
        q,
        k_block_col_idx,
        block_mask_ptr,
        k_ptrs,
        v_ptrs,
        offs_m,
        offs_n,
        stride_kt,
        stride_vt,
        stride_bmask_n,
        sm_scale,
        seqlen_k,
        past_len,

@@ -67,9 +76,9 @@ def _fwd_kernel_inner(
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    if LAST_K_BLOCK:
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
                       float('-inf'))

    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    qk -= m_ij[:, None]

@@ -90,19 +99,36 @@ def _fwd_kernel_inner(
    return acc, l_i, m_i


@triton.jit
def _fwd_kernel(
        Q,
        K,
        V,
        sm_scale,
        block_mask_ptr,
        Out,
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qd,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kd,
        stride_vz,
        stride_vh,
        stride_vn,
        stride_vd,
        stride_bmz,
        stride_bmh,
        stride_bmm,
        stride_bmn,
        stride_oz,
        stride_oh,
        stride_om,
        stride_od,
        H,
        N_CTX,
        PAST_LEN,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,

@@ -144,13 +170,19 @@ def _fwd_kernel(
    # loop over k, v and update accumulator
    for col_idx in range(k_block_start, k_block_end):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            col_idx,
            mask_ptrs,
            k_ptrs,
            v_ptrs,
            offs_m,
            offs_n,
            stride_kn,
            stride_vn,
            stride_bmn,
            sm_scale,
            N_CTX,
            PAST_LEN,

@@ -164,13 +196,13 @@ def _fwd_kernel(
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)

    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
        None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


def _forward(ctx,
             q,
             k,
             v,

@@ -180,9 +212,7 @@ def _forward(
             BLOCK_N=64,
             num_warps=None,
             num_stages=1,
             out=None):
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]

@@ -200,11 +230,13 @@ def _forward(
    N_CTX = k.shape[2]
    PAST_LEN = N_CTX - q.shape[2]
    H = q.shape[1]

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        block_sparse_mask,
        o,
        *q.stride(),

@@ -212,7 +244,8 @@ def _forward(
        *v.stride(),
        *block_sparse_mask.stride(),
        *o.stride(),
        H,
        N_CTX,
        PAST_LEN,
        BLOCK_M,
        BLOCK_N,

@@ -224,8 +257,6 @@ def _forward(
    return o


class _sparse_attention(torch.autograd.Function):

    @staticmethod

@@ -239,8 +270,8 @@ class _sparse_attention(torch.autograd.Function):
        raise NotImplementedError("It does not support gradient propagation yet")
        return None, None, None, None, None


block_sparse_triton_fn = _sparse_attention.apply


def test_topk_sparse_attention():

@@ -254,31 +285,28 @@ def test_topk_sparse_attention():
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    print("downsample_len", downsample_len)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device='cuda',
                       dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    print("x_ds.shape", x_ds.shape)
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    # print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)

    # Run Triton kernel
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)

    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

@@ -291,69 +319,66 @@ def test_topk_sparse_attention():
    # print("ref_output", ref_output)
    # print("triton_output", triton_output)

    # Verify accuracy
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference"
    print("Pass topk sparse attention test with qlen == klen")


def test_topk_sparse_attention_qlt_kl():
    BATCH, N_HEADS = 2, 4
    Q_LEN, K_LEN, D_HEAD = 128, 256, 64  # qlen < klen; here, past_len = 256 - 128 = 128.
    TOPK = 1
    BLOCK = 64  # block size used in downsampling
    torch.manual_seed(0)

    # Create inputs.
    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
    # softmax scale
    sm_scale = 1.0 / (D_HEAD**0.5)

    downsample_factor = BLOCK
    print("downsample_factor", downsample_factor)
    downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
    print("downsample_len", downsample_len)
    x_ds = torch.randn(
        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
    # Force the first column to be high so that the first block is always selected.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
    print("block_mask", block_mask)
    print("block_mask.shape", block_mask.shape)

    # Run Triton kernel.
    triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)

    past_len = K_LEN - Q_LEN
    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale

    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
    effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)

    final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
    attn = attn.masked_fill(~final_mask, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)

    # Verify accuracy.
    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
        "Triton output doesn't match reference when qlen < klen"
    print("Pass topk sparse attention test with qlen < klen")


if __name__ == "__main__":
    test_topk_sparse_attention()
    test_topk_sparse_attention_qlt_kl()