Unverified commit 7cd0da99, authored by Tong WU and committed by GitHub

[Example] Add support for `bfloat16` and user-defined `sm_scale` in attention sink examples (#924)



* revert split+sum template for MHA backward

* lint

* Update example_mha_bwd.py

* Update example_mha_bwd_wgmma_pipelined.py

* Refactor attention sink examples to support bf16 and user-defined softmax scale

* fix typos

* Add compile flags for fast-math optimizations and enable BF16 support in both the GQA and MHA backward implementations.

* Update backward configuration for GQA and MHA examples to align with flash attention

* Refactor GQA backward implementation to improve atomic add performance

* Allow for slightly larger numerical error for bf16

* Update README to show bf16 benchmark results

* lint

* fix ci and lint

* fix comments and lint

* refactor atomic add

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 8f07b9b0
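The attention sink examples now expose a `dtype` argument ("float16" or "bfloat16"), an optional user-defined `sm_scale` (defaulting to 1/sqrt(dim) when None), and a `--dtype` CLI flag. A hedged usage sketch of the updated forward kernel factory follows; the import path is a placeholder, not part of this PR, and running it requires tilelang plus a CUDA GPU:

```python
# Sketch only: shows the new keyword arguments introduced in this PR.
# from example_gqa_sink_fwd_bwd import flashattn_fwd  # hypothetical module name
import torch

BATCH, H, N_CTX, D_HEAD, groups = 1, 64, 4096, 128, 8

kernel = flashattn_fwd(
    BATCH, H, N_CTX, D_HEAD, groups,
    window_size=None,               # None means full attention
    sm_scale=1.0 / D_HEAD**0.5,     # user-defined softmax scale; None keeps this default
    dtype="bfloat16",               # new: "float16" or "bfloat16"
)

q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch.bfloat16, device="cuda")
k = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch.bfloat16, device="cuda")
v = torch.randn_like(k)
sinks = torch.randn(H, dtype=torch.bfloat16, device="cuda")
o, lse = kernel(q, k, v, sinks)     # out_idx=[3, 4]: the kernel returns O and lse
```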
...@@ -206,8 +206,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -206,8 +206,7 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
for i, j in T.Parallel(block_M, dim_v): for i, j in T.Parallel(block_M, dim_v):
T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j]) T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
......
# Attention Sink
-We compare with an optimized version of the official Triton implementation at [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
+We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
## Algorithm
@@ -25,22 +25,22 @@ where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th batch
### Results
-- dtype=float16
+- dtype=bfloat16
- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
- Full attention is adopted.

| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
|---------|---------|---------------|-----------------|---------|
-| 2048 | 64 | 231.55 | **277.07** | 1.20x |
+| 2048 | 64 | 232.98 | **281.89** | 1.21x |
-| 2048 | 128 | 313.55 | **393.98** | 1.26x |
+| 2048 | 128 | 321.55 | **417.98** | 1.30x |
-| 4096 | 64 | 272.17 | **337.30** | 1.24x |
+| 4096 | 64 | 280.70 | **349.47** | 1.25x |
-| 4096 | 128 | 356.35 | **461.54** | 1.30x |
+| 4096 | 128 | 369.61 | **497.13** | 1.35x |
-| 8192 | 64 | 289.93 | **353.81** | 1.22x |
+| 8192 | 64 | 299.04 | **385.56** | 1.29x |
-| 8192 | 128 | 392.18 | **482.50** | 1.23x |
+| 8192 | 128 | 399.39 | **507.93** | 1.27x |
-| 16384 | 64 | 299.52 | **377.44** | 1.26x |
+| 16384 | 64 | 309.46 | **400.62** | 1.29x |
-| 16384 | 128 | 404.64 | **519.02** | 1.28x |
+| 16384 | 128 | 418.99 | **549.11** | 1.31x |
-> The backward performance will be further optimized via fine-grained manual pipelining of FA3 in the tilelang kernel.
+> The backward performance will be further optimized in the future.
\ No newline at end of file
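To make the README's notation concrete: the sink is an extra per-head logit that joins the softmax but carries no value, and $P_{b, h, q}$ is the share of probability mass it absorbs. Below is a minimal single-head PyTorch sketch under that assumption; it is illustrative only and not the examples' `ref_program`, which additionally handles GQA, sliding windows, and fp16/bf16.

```python
import torch

def attention_with_sink(q, k, v, sink, sm_scale=None):
    """q, k, v: [seq_len, dim] for one head; sink: scalar logit for that head."""
    if sm_scale is None:
        sm_scale = q.shape[-1] ** -0.5            # default scale, as in the examples
    scores = (q @ k.T) * sm_scale                 # [seq_q, seq_kv]
    sink_col = sink.reshape(1, 1).expand(scores.shape[0], 1)
    probs = torch.softmax(torch.cat([sink_col, scores], dim=-1), dim=-1)
    p_sink = probs[:, 0]                          # P: proportion of the sink in the softmax
    out = probs[:, 1:] @ v                        # the sink contributes no value vector
    return out, p_sink

q = torch.randn(8, 64)
k = torch.randn(8, 64)
v = torch.randn(8, 64)
sink = torch.tensor(0.25)
o, p_sink = attention_with_sink(q, k, v, sink)
```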
...@@ -5,43 +5,50 @@ import tilelang ...@@ -5,43 +5,50 @@ import tilelang
from tilelang.profiler import do_bench from tilelang.profiler import do_bench
import tilelang.language as T import tilelang.language as T
import argparse import argparse
from typing import Optional
def get_bwd_configs(): def get_bwd_configs():
sm_major, sm_minor = torch.cuda.get_device_capability() sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor sm_version = sm_major * 10 + sm_minor
if sm_version == 80: if sm_version == 80:
return 64, 64, 1, 128 return 64, 32, 1, 128
elif sm_version == 90: elif sm_version == 90:
return 128, 128, 2, 256 return 128, 32, 2, 256
else: else:
raise ValueError(f"Unsupported SM version: {sm_version}") raise ValueError(f"Unsupported SM version: {sm_version}")
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn_fwd( def flashattn_fwd(
batch, batch,
heads, heads,
seq_len, seq_len,
dim, dim,
groups=1, groups=1,
window_size=None, # None for full attention, window_size=None, # None for full attention
block_M=128, sm_scale=None,
block_N=128, block_M=64,
num_stages=2, block_N=64,
threads=256): num_stages=1,
threads=128,
dtype: str = "float16"):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim] q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
...@@ -133,11 +140,12 @@ def flashattn_fwd( ...@@ -133,11 +140,12 @@ def flashattn_fwd(
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): compile_flags=["-O3", "-DENABLE_BF16"])
dtype = "float16" def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
blk = 32 blk = 32
...@@ -172,11 +180,12 @@ def make_dq_layout(dQ): ...@@ -172,11 +180,12 @@ def make_dq_layout(dQ):
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn_bwd_postprocess(batch, heads, seq_len, dim): compile_flags=["-O3", "-DENABLE_BF16"])
dtype = "float16" def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
blk = 64 blk = 64
...@@ -196,16 +205,26 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -196,16 +205,26 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None for full attention },
sm_scale = (1.0 / dim)**0.5 compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn_bwd(batch,
heads,
seq_len,
dim,
groups,
window_size=None,
sm_scale=None,
dtype="float16"): # None for full attention
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim] q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
block_M, block_N, num_stages, threads = get_bwd_configs() block_M, block_N, num_stages, threads = get_bwd_configs()
...@@ -222,8 +241,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None ...@@ -222,8 +241,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(kv_shape, dtype), # type: ignore dK: T.Tensor(kv_shape, accum_dtype), # type: ignore
dV: T.Tensor(kv_shape, dtype), # type: ignore dV: T.Tensor(kv_shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -240,8 +259,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None ...@@ -240,8 +259,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None
dv = T.alloc_fragment([block_M, dim], accum_dtype) dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype) dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype) dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({ T.annotate_layout({
dQ: make_dq_layout(dQ), dQ: make_dq_layout(dQ),
...@@ -281,7 +300,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None ...@@ -281,7 +300,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
...@@ -292,21 +311,18 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None ...@@ -292,21 +311,18 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None
T.copy(dsT_cast, dsT_shared) T.copy(dsT_cast, dsT_shared)
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim): T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
for i, j in T.Parallel(block_M, dim): T.copy(dv, dv_shared)
T.atomic_add(dV[bz, bx // groups, by * block_M + i, j], dv[i, j]) T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
for i, j in T.Parallel(block_M, dim): T.copy(dk, dk_shared)
T.atomic_add(dK[bz, bx // groups, by * block_M + i, j], dk[i, j]) T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared)
return flash_bwd return flash_bwd
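The GQA backward above now accumulates dK/dV in `accum_dtype` (float32) buffers and issues block-wide atomic adds from shared memory, since several query heads contribute gradients to the same KV head. A minimal PyTorch sketch of that reduction pattern, with assumed toy shapes and names (not the kernel's actual code path):

```python
import torch

B, H, S, D, groups = 1, 8, 16, 32, 4        # toy sizes; 8 query heads map to 2 KV heads
head_kv = H // groups

# Stand-ins for the per-query-head dK contributions each kernel block produces.
dk_partial = torch.randn(B, H, S, D, dtype=torch.bfloat16)

# Accumulate in float32 (the accum_dtype buffers targeted by atomicAdd),
# then cast back to the storage dtype once every head in the group has contributed.
dK = torch.zeros(B, head_kv, S, D, dtype=torch.float32)
for h in range(H):
    dK[:, h // groups] += dk_partial[:, h].float()
dK = dK.to(dk_partial.dtype)
```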
@tilelang.jit(out_idx=-1) @tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=256): def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len] shape = [batch, heads, seq_len]
...@@ -338,8 +354,16 @@ class _attention(torch.autograd.Function): ...@@ -338,8 +354,16 @@ class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, sinks, window_size, groups): def forward(ctx, q, k, v, sinks, window_size, groups):
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
BATCH, H, N_CTX, D_HEAD = q.shape BATCH, H, N_CTX, D_HEAD = q.shape
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size) dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks) o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse) ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size ctx.window_size = window_size
...@@ -351,27 +375,22 @@ class _attention(torch.autograd.Function): ...@@ -351,27 +375,22 @@ class _attention(torch.autograd.Function):
q, k, v, sinks, o, lse = ctx.saved_tensors q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape BATCH, H, N_CTX, D_HEAD = q.shape
groups = ctx.groups groups = ctx.groups
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
def maybe_contiguous(x): kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
if x.stride(-1) != 1: kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
return x.contiguous()
return x
do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do) delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size) kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype)
q_shape = [BATCH, H, N_CTX, D_HEAD] q_shape = [BATCH, H, N_CTX, D_HEAD]
head_kv = H // groups head_kv = H // groups
kv_shape = [BATCH, head_kv, N_CTX, D_HEAD] kv_shape = [BATCH, head_kv, N_CTX, D_HEAD]
dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.zeros(kv_shape, dtype=torch.float16, device=q.device) dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(kv_shape, dtype=torch.float16, device=q.device) dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv) kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq) dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX) kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None, None return dq, dk, dv, dsinks, None, None
...@@ -385,7 +404,8 @@ def ref_program(query: torch.Tensor, ...@@ -385,7 +404,8 @@ def ref_program(query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
sinks: torch.Tensor, sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
...@@ -423,7 +443,7 @@ def ref_program(query: torch.Tensor, ...@@ -423,7 +443,7 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16) head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
...@@ -432,7 +452,9 @@ def main(BATCH: int = 1, ...@@ -432,7 +452,9 @@ def main(BATCH: int = 1,
N_CTX: int = 512, N_CTX: int = 512,
D_HEAD: int = 64, D_HEAD: int = 64,
groups: int = 2, groups: int = 2,
window_size: int | None = None): window_size: int | None = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print('Using sliding window attention.')
assert window_size <= N_CTX assert window_size <= N_CTX
...@@ -443,14 +465,11 @@ def main(BATCH: int = 1, ...@@ -443,14 +465,11 @@ def main(BATCH: int = 1,
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
Q = ( Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_())
torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.float16, K = torch.randn(
device="cuda").normal_().requires_grad_()) BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
K = torch.empty( V = torch.randn_like(K).requires_grad_()
BATCH, H // groups, N_CTX, D_HEAD, dtype=torch.float16, sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_()
device="cuda").normal_().requires_grad_()
V = torch.empty_like(K).normal_().requires_grad_()
sinks = torch.randn(H, dtype=torch.float16, device="cuda").requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size, groups) O = attention(Q, K, V, sinks, window_size, groups)
...@@ -460,7 +479,7 @@ def main(BATCH: int = 1, ...@@ -460,7 +479,7 @@ def main(BATCH: int = 1,
dV, V.grad = V.grad.clone(), None dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size) O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None dK_ref, K.grad = K.grad.clone(), None
...@@ -468,11 +487,20 @@ def main(BATCH: int = 1, ...@@ -468,11 +487,20 @@ def main(BATCH: int = 1,
dsinks_ref, sinks.grad = sinks.grad.clone(), None dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks # Checks
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) rtol, atol = {
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) "float16": (1e-2, 1e-2),
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) "bfloat16": (2e-2, 2e-2),
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) }[dtype]
assert torch.allclose(dsinks, dsinks_ref, rtol=1e-2, atol=1e-2), f'{dsinks=}, {dsinks_ref=}' assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}'
assert torch.allclose(
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}'
assert torch.allclose(
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}'
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅") print("All checks passed for tilelang kernels.✅")
...@@ -495,7 +523,7 @@ if __name__ == "__main__": ...@@ -495,7 +523,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size') parser.add_argument('--batch', type=int, default=1, help='Batch size')
parser.add_argument('--h', type=int, default=64, help='Number of heads') parser.add_argument('--h', type=int, default=64, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--n_ctx', type=int, default=4096, help='Context size')
parser.add_argument('--d_head', type=int, default=128, help='Head dimension') parser.add_argument('--d_head', type=int, default=128, help='Head dimension')
parser.add_argument('--groups', type=int, default=8, help='Groups') parser.add_argument('--groups', type=int, default=8, help='Groups')
parser.add_argument( parser.add_argument(
...@@ -503,5 +531,7 @@ if __name__ == "__main__": ...@@ -503,5 +531,7 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help='window size (default: None, which means full attention)') help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size) main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype)
...@@ -12,6 +12,7 @@ import argparse ...@@ -12,6 +12,7 @@ import argparse
import triton import triton
import triton.language as tl import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Optional
def get_configs(): def get_configs():
...@@ -25,9 +26,11 @@ def get_configs(): ...@@ -25,9 +26,11 @@ def get_configs():
rep=100, rep=100,
) )
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn( def flashattn(
batch, batch,
heads, heads,
...@@ -36,20 +39,24 @@ def flashattn( ...@@ -36,20 +39,24 @@ def flashattn(
dim, dim,
groups=1, groups=1,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None,
block_M=128, block_M=128,
block_N=128, block_N=128,
num_stages=2, num_stages=2,
threads=256, threads=256,
dtype: str = "float16",
): ):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim] kv_shape = [batch, head_kv, seq_kv, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
past_len = seq_kv - seq_q past_len = seq_kv - seq_q
...@@ -205,7 +212,8 @@ def ref_program(query: torch.Tensor, ...@@ -205,7 +212,8 @@ def ref_program(query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
sinks: torch.Tensor, sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
...@@ -243,7 +251,7 @@ def ref_program(query: torch.Tensor, ...@@ -243,7 +251,7 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16) head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
...@@ -363,12 +371,18 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens ...@@ -363,12 +371,18 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens
return o return o
def gen_inputs(B, H, Sq, Skv, D, def gen_inputs(
groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: B,
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda') H,
key = torch.randn([B, H // groups, Skv, D], dtype=torch.float16, device='cuda') Sq,
value = torch.randn([B, H // groups, Skv, D], dtype=torch.float16, device='cuda') Skv,
sinks = torch.randn([H], dtype=torch.float16, device='cuda') D,
groups,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks return query, key, value, sinks
...@@ -380,8 +394,10 @@ def main( ...@@ -380,8 +394,10 @@ def main(
dim: int = 128, dim: int = 128,
groups: int = 8, groups: int = 8,
window_size: int | None = None, window_size: int | None = None,
dtype: str = "float16",
tune: bool = False, tune: bool = False,
): ):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print('Using sliding window attention.')
assert window_size <= seq_q assert window_size <= seq_q
...@@ -393,7 +409,7 @@ def main( ...@@ -393,7 +409,7 @@ def main(
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
if tune: if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size) kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}") print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}") print(f"Best config: {kernel.config}")
...@@ -415,17 +431,21 @@ def main( ...@@ -415,17 +431,21 @@ def main(
block_M=block_M, block_M=block_M,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads) threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
torch.testing.assert_close( torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2) kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
if torch.allclose( if torch.allclose(
triton_program(Q, K, V, sinks, window_size), triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2, rtol=1e-2,
atol=1e-2): atol=1e-2):
print("Checks for triton passed.✅") print("Checks for triton passed.✅")
...@@ -458,7 +478,9 @@ if __name__ == "__main__": ...@@ -458,7 +478,9 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help='window size (default: None, which means full attention)') help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.tune) args.dtype, args.tune)
...@@ -5,40 +5,47 @@ import tilelang ...@@ -5,40 +5,47 @@ import tilelang
from tilelang.profiler import do_bench from tilelang.profiler import do_bench
import tilelang.language as T import tilelang.language as T
import argparse import argparse
from typing import Optional
def get_bwd_configs(): def get_bwd_configs():
sm_major, sm_minor = torch.cuda.get_device_capability() sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor sm_version = sm_major * 10 + sm_minor
if sm_version == 80: if sm_version == 80:
return 64, 64, 1, 128 return 64, 32, 1, 128
elif sm_version == 90: elif sm_version == 90:
return 128, 128, 2, 256 return 128, 32, 2, 256
else: else:
raise ValueError(f"Unsupported SM version: {sm_version}") raise ValueError(f"Unsupported SM version: {sm_version}")
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn_fwd( def flashattn_fwd(
batch, batch,
heads, heads,
seq_len, seq_len,
dim, dim,
window_size=None, # None for full attention, window_size=None, # None for full attention,
sm_scale=None,
block_M=64, block_M=64,
block_N=64, block_N=64,
num_stages=1, num_stages=1,
threads=128): threads=128,
dtype: str = "float16"):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
...@@ -52,7 +59,6 @@ def flashattn_fwd( ...@@ -52,7 +59,6 @@ def flashattn_fwd(
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
...@@ -72,9 +78,7 @@ def flashattn_fwd( ...@@ -72,9 +78,7 @@ def flashattn_fwd(
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
# T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32') start = T.alloc_local([1], 'int32')
if window_size is not None: if window_size is not None:
...@@ -133,11 +137,12 @@ def flashattn_fwd( ...@@ -133,11 +137,12 @@ def flashattn_fwd(
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): compile_flags=["-O3", "-DENABLE_BF16"])
dtype = "float16" def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
blk = 32 blk = 32
...@@ -172,11 +177,12 @@ def make_dq_layout(dQ): ...@@ -172,11 +177,12 @@ def make_dq_layout(dQ):
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn_bwd_postprocess(batch, heads, seq_len, dim): compile_flags=["-O3", "-DENABLE_BF16"])
dtype = "float16" def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
blk = 64 blk = 64
...@@ -196,23 +202,28 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -196,23 +202,28 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn_bwd( def flashattn_bwd(
batch, batch,
heads, heads,
seq_len, seq_len,
dim, dim,
window_size=None, # None for full attention, window_size=None, # None for full attention
sm_scale=None,
dtype: str = "float16",
): ):
block_M, block_N, num_stages, threads = get_bwd_configs() block_M, block_N, num_stages, threads = get_bwd_configs()
sm_scale = (1.0 / dim)**0.5 if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
if window_size is not None: if window_size is not None:
...@@ -301,9 +312,8 @@ def flashattn_bwd( ...@@ -301,9 +312,8 @@ def flashattn_bwd(
T.copy(dsT_cast, dsT_shared) T.copy(dsT_cast, dsT_shared)
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim): T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
...@@ -313,8 +323,7 @@ def flashattn_bwd( ...@@ -313,8 +323,7 @@ def flashattn_bwd(
@tilelang.jit(out_idx=-1) @tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=128): def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"):
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len] shape = [batch, heads, seq_len]
...@@ -323,13 +332,13 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128): ...@@ -323,13 +332,13 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128):
Sinks: T.Tensor([heads], dtype), # type: ignore Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, dtype), # type: ignore dsinks: T.Tensor(shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
sink = T.alloc_local([1], dtype) sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype) lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype) delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], dtype) dsink_fragment = T.alloc_fragment([block], accum_dtype)
sink[0] = Sinks[bx] sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment)
...@@ -347,9 +356,8 @@ class _attention(torch.autograd.Function): ...@@ -347,9 +356,8 @@ class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, sinks, window_size): def forward(ctx, q, k, v, sinks, window_size):
BATCH, H, N_CTX, D_HEAD = q.shape BATCH, H, N_CTX, D_HEAD = q.shape
block_M = 64 dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
block_N = 64 if D_HEAD <= 128 else 32 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, block_M, block_N)
o, lse = kernel(q, k, v, sinks) o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse) ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size ctx.window_size = window_size
...@@ -366,18 +374,19 @@ class _attention(torch.autograd.Function): ...@@ -366,18 +374,19 @@ class _attention(torch.autograd.Function):
return x return x
do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do) delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size) kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype)
shape = [BATCH, H, N_CTX, D_HEAD] shape = [BATCH, H, N_CTX, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.empty(shape, dtype=torch.float16, device=q.device) dk = torch.empty(shape, dtype=q.dtype, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device) dv = torch.empty(shape, dtype=q.dtype, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv) kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq) dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX) kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None return dq, dk, dv, dsinks, None
...@@ -391,7 +400,8 @@ def ref_program(query: torch.Tensor, ...@@ -391,7 +400,8 @@ def ref_program(query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
sinks: torch.Tensor, sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze( query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface 3) # align with the original function's interface
...@@ -404,7 +414,7 @@ def ref_program(query: torch.Tensor, ...@@ -404,7 +414,7 @@ def ref_program(query: torch.Tensor,
sm_scale: float = 1.0 / head_dim**0.5 sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1)
key = key.unsqueeze(3) key = key.unsqueeze(3)
value = value.unsqueeze(3) value = value.unsqueeze(3)
...@@ -430,7 +440,7 @@ def ref_program(query: torch.Tensor, ...@@ -430,7 +440,7 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16) head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
...@@ -438,7 +448,9 @@ def main(BATCH: int = 1, ...@@ -438,7 +448,9 @@ def main(BATCH: int = 1,
H: int = 1, H: int = 1,
N_CTX: int = 512, N_CTX: int = 512,
D_HEAD: int = 128, D_HEAD: int = 128,
window_size: int | None = None): window_size: int | None = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print('Using sliding window attention.')
assert window_size <= N_CTX assert window_size <= N_CTX
...@@ -449,12 +461,10 @@ def main(BATCH: int = 1, ...@@ -449,12 +461,10 @@ def main(BATCH: int = 1,
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
Q = ( Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_())
torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, K = torch.randn_like(Q).requires_grad_()
device="cuda").normal_().requires_grad_()) V = torch.randn_like(Q).requires_grad_()
K = torch.empty_like(Q).normal_().requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
sinks = torch.randn(H, dtype=torch.float16, device=Q.device).requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size) O = attention(Q, K, V, sinks, window_size)
...@@ -464,7 +474,7 @@ def main(BATCH: int = 1, ...@@ -464,7 +474,7 @@ def main(BATCH: int = 1,
dV, V.grad = V.grad.clone(), None dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size) O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None dK_ref, K.grad = K.grad.clone(), None
...@@ -472,11 +482,20 @@ def main(BATCH: int = 1, ...@@ -472,11 +482,20 @@ def main(BATCH: int = 1,
dsinks_ref, sinks.grad = sinks.grad.clone(), None dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks # Checks
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) rtol, atol = {
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) "float16": (1e-2, 1e-2),
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) "bfloat16": (2e-2, 2e-2),
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) }[dtype]
assert torch.allclose(dsinks, dsinks_ref, rtol=1e-2, atol=1e-2), f'{dsinks=}, {dsinks_ref=}' assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}'
assert torch.allclose(
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}'
assert torch.allclose(
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}'
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅") print("All checks passed for tilelang kernels.✅")
...@@ -498,13 +517,15 @@ def main(BATCH: int = 1, ...@@ -498,13 +517,15 @@ def main(BATCH: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size') parser.add_argument('--batch', type=int, default=1, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument('--h', type=int, default=64, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--n_ctx', type=int, default=4096, help='Context size')
parser.add_argument('--d_head', type=int, default=128, help='Head dimension') parser.add_argument('--d_head', type=int, default=128, help='Head dimension')
parser.add_argument( parser.add_argument(
'--window_size', '--window_size',
type=int, type=int,
default=None, default=None,
help='window size (default: None, which means full attention)') help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size) main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype)
...@@ -8,6 +8,7 @@ import tilelang.language as T ...@@ -8,6 +8,7 @@ import tilelang.language as T
from tilelang.layout import make_swizzled_layout from tilelang.layout import make_swizzled_layout
import itertools import itertools
import argparse import argparse
from typing import Optional
def get_configs(): def get_configs():
...@@ -17,9 +18,11 @@ def get_configs(): ...@@ -17,9 +18,11 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100) @autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn( def flashattn(
batch, batch,
heads, heads,
...@@ -27,17 +30,20 @@ def flashattn( ...@@ -27,17 +30,20 @@ def flashattn(
seq_kv, seq_kv,
dim, dim,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None,
block_M=64, block_M=64,
block_N=64, block_N=64,
num_stages=1, num_stages=1,
threads=128): threads=128,
dtype: str = "float16"):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
past_len = seq_kv - seq_q past_len = seq_kv - seq_q
...@@ -186,7 +192,8 @@ def ref_program(query: torch.Tensor, ...@@ -186,7 +192,8 @@ def ref_program(query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
sinks: torch.Tensor, sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze( query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface 3) # align with the original function's interface
...@@ -225,15 +232,21 @@ def ref_program(query: torch.Tensor, ...@@ -225,15 +232,21 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16) head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: def gen_inputs(
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda') B,
key = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda') H,
value = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda') Sq,
sinks = torch.zeros([H], dtype=torch.float16, device='cuda') Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks return query, key, value, sinks
...@@ -243,7 +256,9 @@ def main(batch: int = 1, ...@@ -243,7 +256,9 @@ def main(batch: int = 1,
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: int | None = None,
dtype: str = "float16",
tune: bool = False): tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print('Using sliding window attention.')
assert window_size <= seq_q assert window_size <= seq_q
...@@ -255,7 +270,7 @@ def main(batch: int = 1, ...@@ -255,7 +270,7 @@ def main(batch: int = 1,
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
if tune: if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size) kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}") print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}") print(f"Best config: {kernel.config}")
...@@ -276,15 +291,20 @@ def main(batch: int = 1, ...@@ -276,15 +291,20 @@ def main(batch: int = 1,
block_M=block_M, block_M=block_M,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads) threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close( torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2) kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size), warmup=500) latency = do_bench(
lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
...@@ -304,6 +324,9 @@ if __name__ == "__main__": ...@@ -304,6 +324,9 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help='window size (default: None, which means full attention)') help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune') parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.tune) main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
...@@ -12,6 +12,7 @@ import argparse ...@@ -12,6 +12,7 @@ import argparse
import triton import triton
import triton.language as tl import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Optional
def get_configs(): def get_configs():
...@@ -21,9 +22,11 @@ def get_configs(): ...@@ -21,9 +22,11 @@ def get_configs():
@autotune(configs=get_configs(), warmup=500, rep=100) @autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn( def flashattn(
batch, batch,
heads, heads,
...@@ -31,18 +34,22 @@ def flashattn( ...@@ -31,18 +34,22 @@ def flashattn(
seq_kv, seq_kv,
dim, dim,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None,
block_M=128, block_M=128,
block_N=128, block_N=128,
num_stages=2, num_stages=2,
threads=256): threads=256,
dtype: str = "float16"):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
past_len = seq_kv - seq_q past_len = seq_kv - seq_q
...@@ -198,7 +205,8 @@ def ref_program(query: torch.Tensor, ...@@ -198,7 +205,8 @@ def ref_program(query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
sinks: torch.Tensor, sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze( query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface 3) # align with the original function's interface
...@@ -237,7 +245,7 @@ def ref_program(query: torch.Tensor, ...@@ -237,7 +245,7 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16) head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
...@@ -354,11 +362,17 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens ...@@ -354,11 +362,17 @@ def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tens
return o return o
def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: def gen_inputs(
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda') B,
key = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda') H,
value = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda') Sq,
sinks = torch.randn([H], dtype=torch.float16, device='cuda') Skv,
D,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks return query, key, value, sinks
...@@ -368,7 +382,9 @@ def main(batch: int = 1, ...@@ -368,7 +382,9 @@ def main(batch: int = 1,
seq_kv: int = 256, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
window_size: int | None = None, window_size: int | None = None,
dtype: str = "float16",
tune: bool = False): tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print('Using sliding window attention.')
assert window_size <= seq_q assert window_size <= seq_q
...@@ -380,7 +396,7 @@ def main(batch: int = 1, ...@@ -380,7 +396,7 @@ def main(batch: int = 1,
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
if tune: if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size) kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}") print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}") print(f"Best config: {kernel.config}")
...@@ -401,17 +417,21 @@ def main(batch: int = 1, ...@@ -401,17 +417,21 @@ def main(batch: int = 1,
block_M=block_M, block_M=block_M,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads) threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close( torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2) kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
if torch.allclose( if torch.allclose(
triton_program(Q, K, V, sinks, window_size), triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2, rtol=1e-2,
atol=1e-2): atol=1e-2):
print("Checks for triton passed.✅") print("Checks for triton passed.✅")
...@@ -438,6 +458,9 @@ if __name__ == "__main__": ...@@ -438,6 +458,9 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help='window size (default: None, which means full attention)') help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune') parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.tune) main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
...@@ -235,8 +235,7 @@ def flashattn_bwd_atomic_add(batch, ...@@ -235,8 +235,7 @@ def flashattn_bwd_atomic_add(batch,
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
...@@ -340,8 +339,7 @@ def flashattn_bwd_split(batch, ...@@ -340,8 +339,7 @@ def flashattn_bwd_split(batch,
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
......
...@@ -245,8 +245,7 @@ def flashattn_bwd_atomic_add(batch, ...@@ -245,8 +245,7 @@ def flashattn_bwd_atomic_add(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0) T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
...@@ -362,8 +361,7 @@ def flashattn_bwd_split(batch, ...@@ -362,8 +361,7 @@ def flashattn_bwd_split(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0) T.wait_wgmma(0)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
......
...@@ -229,8 +229,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -229,8 +229,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim): for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len: T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
......