"...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a4bf3a98b1c7cb1e4bae930f4e3962cb410ff8b6"
Unverified commit ae9a6f0a authored by Tong WU, committed by GitHub

[Refactor][Example] Update linear attention examples and add tests (#1010)



* [Refactor][Example] Update linear attention examples and add tests

- Refactored the backward and forward linear attention kernels to use shared memory and atomic additions for improved performance (a plain-PyTorch sketch of this accumulation scheme follows this list).
- Introduced L2 normalization in the main functions of both examples.
- Added a new test suite for the linear attention examples to ensure correctness and performance.
- Updated argument parsing in the main functions for better usability.
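
A minimal plain-PyTorch sketch of the two ideas above (hypothetical shapes; the actual TileLang kernels are in the diff below). It contrasts the old "write per-block partials into an extra leading dimension, then sum" scheme with the new "accumulate every block's partial into a single fp32 buffer" scheme that the kernels now implement with atomic adds, and shows the L2 normalization applied to q/k before the kernels run:

import torch

B, S, H, D, BK = 1, 128, 2, 128, 64
NK = D // BK  # number of K-blocks that contribute to the same output

# Old scheme: one partial result per K-block, reduced on the host afterwards.
partials = torch.randn(NK, B, S, H, D)
out_old = partials.sum(0)

# New scheme: a single fp32 accumulator; the kernels add each block's
# contribution with atomic adds (emulated here with in-place +=).
out_new = torch.zeros(B, S, H, D)
for i in range(NK):
    out_new += partials[i]
torch.testing.assert_close(out_old, out_new)

# L2 normalization of q (and likewise k) along the head dimension, as
# fla's l2norm_fwd does, before calling the linear-attention kernels.
q = torch.randn(B, S, H, D, dtype=torch.float16)
q = (q.float() / q.float().norm(dim=-1, keepdim=True).clamp_min(1e-6)).half()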

* Update docstring for TMA atomic add

* lint

* Add flash-linear-attention dependency to requirements.txt

* Rename main function to chunk_linear_attn_bwd

* Rename main function to chunk_linear_attn_fwd

* chore

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent b7dfdb39
example_linear_attn_bwd.py

import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
from fla.modules.l2norm import l2norm_fwd
from einops import rearrange
from typing import Optional, Tuple


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def tl_fused_chunk_bwd_kernel(
        B,
        S,
        H,
@@ -30,19 +31,19 @@ def chunk_linear_attn_bwd_kernel(
    chunk_size = 64
    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
    NK = tilelang.cdiv(DK, BK)
    NV = tilelang.cdiv(DV, BV)
    NT = tilelang.cdiv(S, chunk_size)

    @T.prim_func
    def fused_chunk_linear_attn_bwd(
            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
            dO: T.Tensor([B, S, H, DV], dtype),  # type: ignore
            dQ: T.Tensor([B, S, H, DK], accum_dtype),  # type: ignore
            dK: T.Tensor([B, S, H, DK], accum_dtype),  # type: ignore
            dV: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
    ):
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
@@ -51,8 +52,11 @@ def chunk_linear_attn_bwd_kernel(
            ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            dq = T.alloc_fragment([chunk_size, BK], accum_dtype)
            dq_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
            dk = T.alloc_fragment([chunk_size, BK], accum_dtype)
            dk_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
            dv = T.alloc_fragment([chunk_size, BV], accum_dtype)
            dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
            v = T.alloc_shared([chunk_size, BV], dtype)
@@ -61,22 +65,19 @@ def chunk_linear_attn_bwd_kernel(
            h_shared = T.alloc_shared([BV, BK], dtype)
            dh = T.alloc_fragment([BK, BV], accum_dtype)
            dh_shared = T.alloc_shared([BK, BV], dtype)
            T.annotate_layout({
                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
            })
            T.use_swizzle(10)
            T.clear(h)
            T.clear(dh)

            # Calculate dQ
            for i in T.Pipelined(0, NT):
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
                T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
@@ -92,12 +93,13 @@ def chunk_linear_attn_bwd_kernel(
                T.gemm(v, k, h, transpose_A=True)
                for row, col in T.Parallel(chunk_size, BK):
                    dq[row, col] *= scale
                T.copy(dq, dq_shared)
                T.atomic_add(
                    dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK],
                    dq_shared)

            # Calculate dK, dV (reversely)
            for i in T.Pipelined(1, NT + 1):
                start = NT - i
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
@@ -131,53 +133,90 @@ def chunk_linear_attn_bwd_kernel(
                # Update dh
                T.gemm(q, do, dh, transpose_A=True)
                T.copy(dk, dk_shared)
                T.atomic_add(
                    dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                       i_k * BK:(i_k + 1) * BK], dk_shared)
                T.copy(dv, dv_shared)
                T.atomic_add(
                    dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                       i_v * BV:(i_v + 1) * BV], dv_shared)

    return fused_chunk_linear_attn_bwd
def tl_fused_chunk_bwd(Q, K, V, dO):
    B, S, H, D = Q.shape
    kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D)
    dQ = torch.zeros_like(Q, dtype=torch.float32)
    dK = torch.zeros_like(K, dtype=torch.float32)
    dV = torch.zeros_like(V, dtype=torch.float32)
    kernel(Q, K, V, dO, dQ, dK, dV)
    return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)


def ref_program(q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    q, k, v = q.float(), k.float(), v.float()
    if scale is None:
        scale = q.shape[-1]**-0.5
    chunk_size = 64
    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    h = kv[:, :, -1, :, :]
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    inter = q @ kv
    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
        0)) @ v
    o = inter + intra
    return rearrange(o, 'b h n c d -> b (n c) h d'), h


def main(B=1, S=1024, H=16, D=128):
    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    # qk norm is necessary for linear attn
    q = l2norm_fwd(q)[0].requires_grad_(True)
    k = l2norm_fwd(k)[0].requires_grad_(True)

    dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do)
    q.grad = k.grad = v.grad = None
    o_ref, _ = ref_program(q, k, v)
    o_ref.backward(do, retain_graph=True)

    assert torch.allclose(
        dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
    assert torch.allclose(
        dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
    assert torch.allclose(
        dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
    print('Passed all tests!✅')

    # Benchmark
    q.grad = k.grad = v.grad = None
    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
    print(f'Triton latency: {t1:.3f} ms')
    print(f'TileLang latency: {t2:.3f} ms')
    print(f'Speedup: {t1/t2:.3f}x')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=1024, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    args = parser.parse_args()
    main(args.B, args.S, args.H, args.D)
example_linear_attn_fwd.py

import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
from fla.modules.l2norm import l2norm_fwd
from einops import rearrange
from typing import Optional, Tuple


@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def tl_fused_chunk_fwd_kernel(
        B,
        S,
        H,
@@ -30,16 +32,16 @@ def chunk_linear_attn_fwd_kernel(
    chunk_size = 64
    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
    NK = tilelang.cdiv(DK, BK)
    NV = tilelang.cdiv(DV, BV)
    NT = tilelang.cdiv(S, chunk_size)

    @T.prim_func
    def fused_chunk_linear_attn_fwd(
            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
            O: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
            final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
@@ -53,18 +55,14 @@ def chunk_linear_attn_fwd_kernel(
            s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            o = T.alloc_fragment([chunk_size, BV], accum_dtype)
            o_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
            T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)})
            T.use_swizzle(10)
            T.clear(h)

            for i in T.Pipelined(0, NT):
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
@@ -78,52 +76,80 @@ def chunk_linear_attn_fwd_kernel(
                T.copy(h, h_shared)
                T.gemm(k, v, h, transpose_A=True)
                T.gemm(q, h_shared, o)
                T.copy(o, o_shared)
                T.atomic_add(
                    O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                    o_shared)
                # TODO: consider using vectorized atomic add or TMA reduce for sm90

            # Output final state
            T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])

    return fused_chunk_linear_attn_fwd
def tl_fused_chunk_fwd(q, k, v):
    B, S, H, D = q.shape
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
    h = kernel(q, k, v, o)
    return o, h


def ref_program(q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    q, k, v = q.float(), k.float(), v.float()
    if scale is None:
        scale = q.shape[-1]**-0.5
    chunk_size = 64
    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    h = kv[:, :, -1, :, :]
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    inter = q @ kv
    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
        0)) @ v
    o = inter + intra
    return rearrange(o, 'b h n c d -> b (n c) h d'), h


def main(B=1, S=512, H=16, D=128):
    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    # qk norm is necessary for linear attn
    q, _ = l2norm_fwd(q)
    k, _ = l2norm_fwd(k)

    o, h = tl_fused_chunk_fwd(q, k, v)
    o_ref, h_ref = ref_program(q, k, v)

    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
    assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}'
    print('Passed all tests!✅')

    t1 = do_bench(
        lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False),
        backend='cupti')
    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
    print(f'Triton latency: {t1:.3f} ms')
    print(f'TileLang latency: {t2:.3f} ms')
    print(f'Speedup: {t1/t2:.3f}x')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=1024, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    args = parser.parse_args()
    main(args.B, args.S, args.H, args.D)
import tilelang.testing

import example_linear_attn_fwd
import example_linear_attn_bwd


@tilelang.testing.requires_cuda
def test_example_linear_attn_fwd():
    example_linear_attn_fwd.main()


@tilelang.testing.requires_cuda
def test_example_linear_attn_bwd():
    example_linear_attn_bwd.main()


if __name__ == "__main__":
    tilelang.testing.main()
requirements.txt

@@ -7,3 +7,4 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
flash-linear-attention==0.3.2
\ No newline at end of file
@@ -128,6 +128,7 @@ def atomic_add(dst: Buffer,
        value (PrimExpr): Value to add atomically.
        memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering.
        return_prev (bool): If True, return the previous value; if False, return a handle (default False).
        use_tma (bool): If True, use TMA (cp.reduce) to perform the atomic add. This is available only on sm90+ (default False).

    Returns:
        PrimExpr: A handle representing the atomic addition operation, or the previous value if return_prev is True.
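
For illustration, a hedged fragment showing where the new use_tma flag would be passed, reusing the copy-to-shared / atomic-add pattern from the backward example above. Everything except use_tma=True is taken from that example, and the fragment assumes it sits inside the kernel's dK/dV loop where dv, dV and the indices are already defined:

# Inside the backward kernel's dK/dV loop (see example_linear_attn_bwd.py above):
dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
T.copy(dv, dv_shared)
# Assumption: on sm90+ hardware, passing use_tma=True lowers this accumulation
# to a TMA cp.reduce instead of per-element atomic adds.
T.atomic_add(
    dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
       i_v * BV:(i_v + 1) * BV],
    dv_shared,
    use_tma=True)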