# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py
import tilelang
import tilelang.language as T
import sys  # noqa: F401

# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
    import fla
    print(fla.__file__)
    from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
    print("fla not found, using a plain PyTorch reference for the check")
    fla = None
    chunk_local_cumsum_scalar = None

import torch
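
# Plain-PyTorch fallback reference for when fla is not installed.
# This is a minimal sketch of a chunk-local cumsum, not the fla implementation:
# it assumes S is a multiple of chunk_size and only mirrors the keyword interface
# that run_test uses below (g, chunk_size, reverse, head_first, output_dtype).
def torch_chunk_local_cumsum_scalar(g,
                                    chunk_size,
                                    reverse=False,
                                    head_first=False,
                                    output_dtype=torch.float32):
    if not head_first:
        g = g.transpose(1, 2)  # (B, S, H) -> (B, H, S)
    B, H, S = g.shape
    # Split the sequence into chunks and cumsum inside each chunk;
    # a reversed (right-to-left) cumsum is flip -> cumsum -> flip.
    g = g.to(output_dtype).reshape(B, H, S // chunk_size, chunk_size)
    g = g.flip(-1).cumsum(-1).flip(-1) if reverse else g.cumsum(-1)
    g = g.reshape(B, H, S)
    if not head_first:
        g = g.transpose(1, 2).contiguous()  # back to (B, S, H)
    return g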

@tilelang.jit(
    out_idx=[-1],  # the last argument (G_new) is allocated by the runtime and returned
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
    })
def tilelang_chunk_local_cumsum_scalar(
        # task config
        B,
        S,
        H,
        chunk_size=64,
        is_varlen=False,
        head_first=False,
        reverse=False,
        input_dtype="float16",
        output_dtype="float32",
        # kernel config
        block_S=64,
        threads=256,
        use_fragment=False,
):
    G_shape = (B, H, S) if head_first else (B, S, H)
    assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    assert chunk_size == block_S, "chunk_size must be equal to block_S"

    @T.prim_func
    def kernel(
            G: T.Tensor(G_shape, dtype=input_dtype),
            G_new: T.Tensor(G_shape, dtype=output_dtype),
    ):
        # One block per (sequence chunk, batch * head) pair.
        with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
            bb, bh = bbh // H, bbh % H
            # Stage one chunk of gates in shared memory (cast to output_dtype on copy).
            G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
            if head_first:
                T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
            else:
                T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
            if use_fragment:
                # Optional path: copy into a separate fragment buffer and run the cumsum there.
                G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
                T.copy(G_shared, G_fragment)
                T.cumsum(G_fragment, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
                else:
                    T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
            else:
                T.cumsum(G_shared, dim=1, reverse=reverse)
                if head_first:
                    T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
                else:
                    T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])

    return kernel


def prepare_cumsum_input(
    B,
    S,
    H,
    dtype,
):
    G = torch.randn(B, S, H, dtype=dtype).cuda()
    return G


def prepare_cumsum_output(
    B,
    S,
    H,
    dtype,
):
    G_new = torch.empty(B, S, H, dtype=dtype).cuda()
    return G_new


def run_test(
    B,
    S,
    H,
    chunk_size,
    reverse,
    head_first,
    input_dtype,
    output_dtype,
    threads,
    use_fragment,
):
    G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype))
    G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
    G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))

    # reference cumsum (fla if available, otherwise the PyTorch sketch above)
    ref_fn = chunk_local_cumsum_scalar if fla is not None else torch_chunk_local_cumsum_scalar
    G_new_ref = ref_fn(
        g=G,
        chunk_size=chunk_size,
        reverse=reverse,
        head_first=head_first,
        output_dtype=getattr(torch, output_dtype))

    # tilelang cumsum
    block_S = chunk_size
    kernel = tilelang_chunk_local_cumsum_scalar(
        B=B,
        S=S,
        H=H,
        chunk_size=chunk_size,
        reverse=reverse,
        head_first=head_first,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        block_S=block_S,
        threads=threads,
        use_fragment=use_fragment,
    )
    torch.cuda.profiler.start()
    G_new_tilelang = kernel(G)
    torch.cuda.profiler.stop()

    try:
        torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
        print("tilelang cumsum passed √")
    except Exception as e:
        print("tilelang cumsum failed ✗")
        print(e)
        print("G:")
        print(G.view(-1))
        print("G_new_tilelang:")
        print(G_new_tilelang.view(-1))
        print("G_new_ref:")
        print(G_new_ref.view(-1))


def main():
    run_test(
        B=1,
        S=32768,
        H=32,
        chunk_size=64,
        reverse=True,
        head_first=False,
        input_dtype="float32",
        output_dtype="float32",
        threads=256,
        use_fragment=False)


if __name__ == "__main__":
    main()