Unverified commit bf67fb19, authored by Tong WU, committed by GitHub

[Example] Optimize sink attention forward via swizzled layout and report benchmark results (#885)



* Enhance attention sink examples with swizzled layout and performance metrics

- Added `make_swizzled_layout` annotations for shared tensors in the `flashattn` function across MHA and GQA examples to optimize memory access patterns.
- Updated benchmark outputs to include speedup calculations comparing Triton and TileLang implementations.

* Add README for Attention Sink example with algorithm details and benchmark results

- Introduced a new README.md file for the Attention Sink example, outlining the forward and backward algorithms, including the computation of `dsinks`.
- Provided benchmark results comparing performance metrics of the optimized implementation against Triton, highlighting speedup across various configurations.

* Update README.md for Attention Sink example to include link to Triton implementation

* Update examples/attention_sink/README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* typo

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent c861d8a2
# Attention Sink
We compare against an optimized version of the official Triton implementation, available [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
## Algorithm
### Forward
The only change from vanilla FlashAttention is that the `sinks` logits must be taken into account in the softmax normalization, which requires an extra rescaling of the output in the epilogue.
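As a rough illustration, here is a minimal NumPy sketch (single head, illustrative names, not the TileLang kernel) of the softmax with a sink: the sink logit joins the normalizer but contributes no value, so only the final rescaling changes.

```python
import numpy as np

def attention_with_sink(q, k, v, sink, scale):
    # q: [S_q, D], k/v: [S_kv, D], sink: learned scalar logit for this head
    logits = (q @ k.T) * scale                     # [S_q, S_kv]
    m = np.maximum(logits.max(axis=-1), sink)      # row max must also cover the sink
    p = np.exp(logits - m[:, None])                # unnormalized probabilities
    denom = p.sum(axis=-1) + np.exp(sink - m)      # sink enters the normalizer only
    return (p @ v) / denom[:, None]                # epilogue rescaling by the full sum
```

In the fused kernel this only affects the final normalization, which is exactly the extra epilogue rescaling mentioned above.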
### Backward
Interestingly, a detailed derivation shows that the backward computation of `dQ`, `dK`, and `dV` is almost identical to that of vanilla FlashAttention, except that the specific meaning of `lse` differs (it now also accounts for the sink logit). The only additional quantity to compute is `dsinks`, given by:
$$
dsink_h = -\sum_{b}\sum_{q} P_{b, h, q}\,\Delta_{b, h, q}
$$
where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax for the $b$-th batch, $h$-th head, and $q$-th query (row), and $\Delta_{b, h, q}$ is the usual FlashAttention backward quantity $\mathrm{rowsum}(dO \odot O)$.
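A minimal NumPy sketch of this reduction (illustrative names; it assumes `lse` is the log-sum-exp over the key logits plus the sink, as produced by the forward pass):

```python
import numpy as np

def compute_dsinks(sinks, lse, o, d_o):
    # sinks: [H], lse: [B, H, S_q], o / d_o: [B, H, S_q, D]
    # Delta_{b,h,q} = rowsum(dO * O), the same quantity used in vanilla FlashAttention backward
    delta = (d_o * o).sum(axis=-1)                  # [B, H, S_q]
    # P_{b,h,q} = exp(sink_h - lse_{b,h,q}): share of softmax mass taken by the sink
    p_sink = np.exp(sinks[None, :, None] - lse)     # [B, H, S_q]
    # dsink_h = -sum over batches and queries of P * Delta
    return -(p_sink * delta).sum(axis=(0, 2))       # [H]
```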
## Benchmark of the forward pass
### Benchmark Environment
- **Hardware**: NVIDIA H800
- **CUDA version**: 12.9
- **Triton Version**: 3.4.0
### Results
- dtype=float16
- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
- Full attention (no sliding window) is used.
| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
|---------|---------|---------------|----------------------|---------|
| 2048 | 64 | 231.55 | **277.07** | 1.20x |
| 2048 | 128 | 313.55 | **393.98** | 1.26x |
| | | | | |
| 4096 | 64 | 272.17 | **337.30** | 1.24x |
| 4096 | 128 | 356.35 | **461.54** | 1.30x |
| | | | | |
| 8192 | 64 | 289.93 | **353.81** | 1.22x |
| 8192 | 128 | 392.18 | **482.50** | 1.23x |
| | | | | |
| 16384 | 64 | 299.52 | **377.44** | 1.26x |
| 16384 | 128 | 404.64 | **519.02** | 1.28x |
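
The GQA numbers above can be reproduced with the example scripts in this directory, e.g. `python example_gqa_sink_fwd_bhsd_wgmma_pipelined.py --batch 1 --heads 64 --groups 8 --seq_q 4096 --seq_kv 4096 --dim 128` (flag names as defined in the script's argument parser); each benchmark prints the Triton and TileLang latencies, TFLOPs, and the resulting speedup.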
> The backward pass will be further optimized via fine-grained manual pipelining (in the style of FlashAttention-3) in the TileLang kernel.
@@ -6,6 +6,7 @@ import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
import triton
@@ -152,6 +153,13 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
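# Swizzle the shared-memory tiles to improve memory access patterns (fewer bank conflicts)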
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
@@ -425,22 +433,24 @@ def main(
print("Checks for triton failed.❌")
# Benchmark triton
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
# Benchmark tilelang
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
@@ -5,6 +5,7 @@ import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
@@ -140,6 +141,13 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
@@ -6,6 +6,7 @@ import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
import triton
@@ -145,6 +146,13 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)