Unverified Commit 3ad6202d authored by Lei Wang, committed by GitHub

[Example] Specify a fixed commit for the flash-linear-attention repository and optimize nsa examples (#913)

- Updated requirements.txt to pin flash-linear-attention to a fixed commit.
- Refactored import paths in benchmark_nsa_fwd.py: prepare_token_indices is now imported from fla.ops.utils instead of fla.ops.common.utils.
- Added a get_configs() helper that generates the candidate configurations for autotuning (see the sketch below).
- Modified tilelang_sparse_attention to accept block_T, num_stages, and threads as parameters, replacing the previously hard-coded values and making them tunable.
- Changed the acc_s_cast accumulator from a register fragment (T.alloc_fragment) to shared memory (T.alloc_shared) to improve performance.
parent f92de932
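
A minimal sketch of the autotuning pattern this commit introduces, for readers skimming the diff below. The get_configs() helper is plain Python and mirrors the value lists in the diff; the trailing comments mirror the tilelang.autotune / tilelang.jit decorators and the kernel/profiler flow that appear further down, with "..." standing in for arguments that are not spelled out here.

    import itertools

    def get_configs():
        # Tuning knobs and their candidate values; each key becomes a keyword
        # argument of the autotuned kernel (value lists copied from the diff below).
        iter_params = dict(
            block_T=[128, 256, 512],
            num_stages=[0, 1, 2, 4, 5],
            threads=[32, 64, 128, 256, 512],
        )
        # Cartesian product of all knobs -> one dict per candidate configuration.
        return [dict(zip(iter_params, values))
                for values in itertools.product(*iter_params.values())]

    # With the decorators used in the diff, the kernel is declared roughly as:
    #
    #     @tilelang.autotune(configs=get_configs())
    #     @tilelang.jit
    #     def tilelang_sparse_attention(batch, heads, seq_len, dim, ...,
    #                                   block_T=128, num_stages=2, threads=32):
    #         ...
    #
    # Calling the decorated function returns a compiled kernel directly, so the
    # benchmark no longer needs the explicit tilelang.compile(program, ...) step:
    #
    #     kernel = tilelang_sparse_attention(batch=..., heads=..., seq_len=..., ...)
    #     profiler = kernel.get_profiler()

Keeping the search space declarative in get_configs() means adding a new knob only requires one more entry in iter_params.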

.gitignore
@@ -92,3 +92,6 @@ tilelang/jit/adapter/cython/.cycache
# cache directory for clangd
.cache/
# claude
**/.claude

benchmark_nsa_fwd.py
@@ -10,7 +10,7 @@ from typing import Optional, Union
from einops import rearrange, repeat
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
@@ -439,6 +439,20 @@ def naive_nsa(q: torch.Tensor,
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def get_configs():
import itertools
iter_params = dict(
block_T=[128, 256, 512],
num_stages=[0, 1, 2, 4, 5],
threads=[32, 64, 128, 256, 512],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit
def tilelang_sparse_attention(batch,
heads,
seq_len,
@@ -447,7 +461,10 @@ def tilelang_sparse_attention(batch,
scale=None,
block_size=64,
groups=1,
selected_blocks=16):
selected_blocks=16,
block_T=128,
num_stages=2,
threads=32):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
else:
@@ -461,7 +478,7 @@ def tilelang_sparse_attention(batch,
dtype = "float16"
accum_dtype = "float"
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
block_T = min(block_T, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
@@ -471,8 +488,6 @@ def tilelang_sparse_attention(batch,
G = groups
BS = block_S
BK = BV = block_T
num_stages = 2
threads = 32
@T.prim_func
def tilelang_sparse_attention(
@@ -489,7 +504,7 @@ def tilelang_sparse_attention(batch,
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_s_cast = T.alloc_shared([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
@@ -497,11 +512,7 @@ def tilelang_sparse_attention(batch,
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
# T.use_swizzle(10)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.annotate_layout({K_shared: tilelang.layout.make_swizzled_layout(K_shared)})
T.annotate_layout({V_shared: tilelang.layout.make_swizzled_layout(V_shared)})
T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
@@ -597,7 +608,7 @@ def benchmark_nsa(batch_size,
torch.random.manual_seed(0)
# Compile the NSA kernel
program = tilelang_sparse_attention(
kernel = tilelang_sparse_attention(
batch=batch_size,
heads=head_query,
seq_len=seq_len,
@@ -608,9 +619,6 @@ selected_blocks=selected_blocks,
selected_blocks=selected_blocks,
scale=scale,
)
print(program)
kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()

requirements.txt
git+https://github.com/fla-org/flash-linear-attention
\ No newline at end of file
git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e
\ No newline at end of file
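
A note on the dependency pin above: the general requirements.txt / pip syntax for a VCS reference is

    git+https://<repository-url>@<ref>

where <ref> may be a branch name, a tag, or a full commit SHA. Pinning to a SHA, as this commit does, keeps installs reproducible even if the upstream default branch moves.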