"example/vscode:/vscode.git/clone" did not exist on "3ab20fd7530e5a878bf73c1d7005a83f3aa26f02"
Commit d79204e5 authored by Lei Wang, committed by GitHub

[Release] Bump version to v0.1.1 (#107)

* Remove Torch CPP backend and update execution backend options

- Remove TorchCPPKernelAdapter and related code from JIT modules
- Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
- Remove "torch_cpp" from supported execution backend literals
- Simplify backend validation and remove unused torch_cpp-related code

* lint fix

* Add block sparse attention implementations for TileLang and Triton

- Implement block sparse attention kernels for TileLang and Triton
- Add example scripts for block sparse attention with top-k and threshold-based masking
- Include utility functions for generating sparse attention masks (a minimal sketch of the masking idea follows this list)
- Demonstrate causal attention with block-level sparsity
- Add test cases to validate sparse attention implementations against PyTorch reference
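For orientation, a minimal sketch of the masking scheme these examples rely on (sizes and names below are illustrative, not the example configs): block-level scores are reduced to a boolean block mask by keeping the top-k key blocks per query block (or by thresholding), a lower-triangular constraint enforces causality at block granularity, and the block mask is expanded with a Kronecker product into a token-level mask for the PyTorch reference check.

import torch

# Illustrative sizes only; the example scripts below use their own configurations.
bsz, heads, num_blocks, block_size, topk = 1, 1, 4, 64, 2
scores = torch.randn(bsz, heads, num_blocks, num_blocks)

# Top-k masking: keep the k highest-scoring key blocks for each query block.
block_mask = torch.zeros_like(scores, dtype=torch.bool)
block_mask.scatter_(-1, torch.topk(scores, topk, dim=-1).indices, True)
block_mask.tril_()  # causal constraint at block granularity

# Threshold masking is the analogous variant: (scores > threshold).tril_()

# Expand to a token-level mask to build a dense PyTorch reference.
full_mask = torch.kron(block_mask.float(), torch.ones(block_size, block_size)).bool()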

* Bump version to 0.1.1

* Refactor block sparse attention examples for improved code quality

- Apply consistent code formatting and style in TileLang and Triton block sparse attention implementations
- Add ruff linter ignore comment for specific line in Triton implementation
- Improve readability by adjusting indentation and line breaks
- Standardize sparse mask generation and test function implementations
- Minor optimizations in test case configurations

* lint
parent c7462abf
@@ -3,6 +3,7 @@ include CMakeLists.txt
include requirements.txt
include requirements-test.txt
include requirements-dev.txt
include tilelang/jit/adapter/cython/cython_wrapper.pyx
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *
0.1.0
\ No newline at end of file
0.1.1
\ No newline at end of file
@@ -7,24 +7,28 @@ import tilelang
import tilelang.language as T
import torch.nn.functional as F
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
@@ -136,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
@@ -165,6 +169,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
return kernel_func(block_M, block_N, num_stages, threads)
def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run Triton kernel
@@ -194,25 +201,24 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
# Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
"TileLang output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
if __name__ == "__main__":
test_topk_sparse_attention()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: E712
import math
import torch
@@ -7,6 +8,7 @@ import triton
import triton.language as tl
import torch.nn.functional as F
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -15,33 +17,40 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
@triton.jit
def _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
k_block_col_idx,
block_mask_ptr,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kt,
stride_vt,
stride_bmask_n,
sm_scale,
seqlen_k,
past_len,
@@ -51,8 +60,8 @@ def _fwd_kernel_inner(
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
# print
if k_block_col_idx == 3:
print("mask_val", mask_val)
if mask_val == True:
@@ -67,9 +76,9 @@ def _fwd_kernel_inner(
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
float('-inf'))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
@@ -78,7 +87,7 @@ def _fwd_kernel_inner(
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vt)
@@ -90,21 +99,38 @@ def _fwd_kernel_inner(
return acc, l_i, m_i
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
block_mask_ptr,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qd,
stride_kz,
stride_kh,
stride_kn,
stride_kd,
stride_vz,
stride_vh,
stride_vn,
stride_vd,
stride_bmz,
stride_bmh,
stride_bmm,
stride_bmn,
stride_oz,
stride_oh,
stride_om,
stride_od,
H,
N_CTX,
PAST_LEN,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
@@ -144,13 +170,19 @@ def _fwd_kernel(
# loop over k, v and update accumulator
for col_idx in range(k_block_start, k_block_end):
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
col_idx,
mask_ptrs,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kn,
stride_vn,
stride_bmn,
sm_scale,
N_CTX,
PAST_LEN,
@@ -162,27 +194,25 @@ def _fwd_kernel(
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
@@ -200,19 +230,22 @@ def _forward(
N_CTX = k.shape[2]
PAST_LEN = N_CTX - q.shape[2]
H = q.shape[1]
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
block_sparse_mask,
o,
*q.stride(),
*k.stride(),
*v.stride(),
*block_sparse_mask.stride(),
*o.stride(),
H,
N_CTX,
PAST_LEN,
BLOCK_M,
BLOCK_N,
@@ -224,8 +257,6 @@ def _forward(
return o
class _sparse_attention(torch.autograd.Function):
@staticmethod
@@ -239,8 +270,8 @@ class _sparse_attention(torch.autograd.Function):
raise NotImplementedError("It does not support gradient propagation yet")
return None, None, None, None, None
block_sparse_triton_fn = _sparse_attention.apply
def test_topk_sparse_attention():
@@ -254,106 +285,100 @@ def test_topk_sparse_attention():
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
print("downsample_len", downsample_len)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
print("x_ds.shape", x_ds.shape)
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
def test_topk_sparse_attention_qlt_kl():
BATCH, N_HEADS = 2, 4
Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128.
TOPK = 1
BLOCK = 64 # block size used in downsampling
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
# softmax scale
sm_scale = 1.0 / (D_HEAD**0.5)
downsample_factor = BLOCK
print("downsample_factor", downsample_factor)
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
print("downsample_len", downsample_len)
x_ds = torch.randn(
BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel.
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
past_len = K_LEN - Q_LEN
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
attn = attn.masked_fill(~final_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
# Verify accuracy.
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference when qlen < klen"
print("Pass topk sparse attention test with qlen < klen")
if __name__ == "__main__":
test_topk_sparse_attention()
test_topk_sparse_attention_qlt_kl()