Unverified Commit 3ad6202d authored by Lei Wang, committed by GitHub

[Example] Specify a fixed commit for the flash-linear-attention repository and optimize nsa examples (#913)

- Updated the requirements.txt to specify a fixed commit for the flash-linear-attention repository.
- Refactored import paths in benchmark_nsa_fwd.py for better organization.
- Added a new function to generate configurations for autotuning (a self-contained sketch of this generator follows the list below).
- Modified the tilelang_sparse_attention function to accept parameters for block size, number of stages, and threads, enhancing flexibility.
- Changed allocation of shared memory for accumulators to optimize performance.
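
A minimal, self-contained sketch of the configuration generator referenced above. The parameter values are copied from the diff below; that the tilelang autotuner benchmarks each candidate and keeps the fastest one is an assumption about its behavior, not something shown here.

import itertools

def get_configs():
    # Cartesian product of the tuning knobs: 3 * 5 * 5 = 75 candidate configs
    iter_params = dict(
        block_T=[128, 256, 512],
        num_stages=[0, 1, 2, 4, 5],
        threads=[32, 64, 128, 256, 512],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

configs = get_configs()
assert len(configs) == 3 * 5 * 5
print(configs[0])  # {'block_T': 128, 'num_stages': 0, 'threads': 32}
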
parent f92de932
@@ -92,3 +92,6 @@ tilelang/jit/adapter/cython/.cycache
 # cache directory for clangd
 .cache/
+# claude
+**/.claude
@@ -10,7 +10,7 @@ from typing import Optional, Union
 from einops import rearrange, repeat
 import triton
 import triton.language as tl
-from fla.ops.common.utils import prepare_token_indices
+from fla.ops.utils import prepare_token_indices
 from fla.utils import autocast_custom_fwd, contiguous
@@ -439,6 +439,20 @@ def naive_nsa(q: torch.Tensor,
     return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
 
 
+def get_configs():
+    import itertools
+    iter_params = dict(
+        block_T=[128, 256, 512],
+        num_stages=[0, 1, 2, 4, 5],
+        threads=[32, 64, 128, 256, 512],
+    )
+    return [{
+        k: v for k, v in zip(iter_params, values)
+    } for values in itertools.product(*iter_params.values())]
+
+
+@tilelang.autotune(configs=get_configs(),)
+@tilelang.jit
 def tilelang_sparse_attention(batch,
                               heads,
                               seq_len,
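
With both decorators applied, calling tilelang_sparse_attention(...) is expected to sweep the 75 generated configurations and return a compiled kernel directly, which is presumably why the benchmark further down no longer calls tilelang.compile itself.
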
@@ -447,7 +461,10 @@ def tilelang_sparse_attention(batch,
                               scale=None,
                               block_size=64,
                               groups=1,
-                              selected_blocks=16):
+                              selected_blocks=16,
+                              block_T=128,
+                              num_stages=2,
+                              threads=32):
     if scale is None:
         scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     else:
@@ -461,7 +478,7 @@ def tilelang_sparse_attention(batch,
     dtype = "float16"
     accum_dtype = "float"
     block_S = block_size
-    block_T = min(128, tilelang.math.next_power_of_2(dim))
+    block_T = min(block_T, tilelang.math.next_power_of_2(dim))
     NK = tilelang.cdiv(dim, block_T)
     NV = tilelang.cdiv(dim, block_T)
@@ -471,8 +488,6 @@ def tilelang_sparse_attention(batch,
     G = groups
     BS = block_S
     BK = BV = block_T
-    num_stages = 2
-    threads = 32
 
     @T.prim_func
     def tilelang_sparse_attention(
@@ -489,7 +504,7 @@ def tilelang_sparse_attention(batch,
             O_shared = T.alloc_shared([G, BV], dtype)
             acc_s = T.alloc_fragment([G, BS], accum_dtype)
-            acc_s_cast = T.alloc_fragment([G, BS], dtype)
+            acc_s_cast = T.alloc_shared([G, BS], dtype)
             acc_o = T.alloc_fragment([G, BV], accum_dtype)
             scores_max = T.alloc_fragment([G], accum_dtype)
             scores_max_prev = T.alloc_fragment([G], accum_dtype)
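
This is the accumulator-allocation change called out in the commit message: acc_s_cast moves from a per-thread fragment (T.alloc_fragment) into shared memory (T.alloc_shared), presumably so the cast attention scores can feed the second matmul without an extra register-to-shared copy.
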
@@ -497,11 +512,7 @@ def tilelang_sparse_attention(batch,
             scores_sum = T.alloc_fragment([G], accum_dtype)
             logsum = T.alloc_fragment([G], accum_dtype)
 
-            # T.use_swizzle(10)
-            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
-            T.annotate_layout({K_shared: tilelang.layout.make_swizzled_layout(K_shared)})
-            T.annotate_layout({V_shared: tilelang.layout.make_swizzled_layout(V_shared)})
+            T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
 
             i_t, i_v, i_bh = bx, by, bz
             i_b, i_h = i_bh // head_kv, i_bh % head_kv
@@ -597,7 +608,7 @@ def benchmark_nsa(batch_size,
     torch.random.manual_seed(0)
 
     # Compile the NSA kernel
-    program = tilelang_sparse_attention(
+    kernel = tilelang_sparse_attention(
         batch=batch_size,
         heads=head_query,
         seq_len=seq_len,
@@ -608,9 +619,6 @@ def benchmark_nsa(batch_size,
         selected_blocks=selected_blocks,
         scale=scale,
     )
-    print(program)
-    kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
-    print(kernel.get_kernel_source())
     profiler = kernel.get_profiler()
-git+https://github.com/fla-org/flash-linear-attention
+git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e
\ No newline at end of file
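
Pinning flash-linear-attention to commit c3bd56589033610264532b11f0972c69e4645f6e makes pip install that exact revision, keeping the example reproducible; it presumably also matches the fla.ops.utils import path adopted above, which earlier revisions exposed as fla.ops.common.utils.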