Commit e89e8b6c authored by Yu Cheng, committed by GitHub

[Doc] Add MLA Decoding Performance Benchmarks and Documentation (#137)

- Update news and MLA performance benchmark in README.md
- Move performance benchmark and layout images to a dedicated 'figures' directory
- Improve code formatting and image references in documentation
parent e32311b2
@@ -11,6 +11,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
+- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this.
- 02/15/2025 ✨: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
- 02/12/2025 ✨: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)!
- 02/10/2025 🚀: Added debug tools for TileLang—`T.print` for printing variables/buffers ([docs](https://tilelang.tile-ai.cn/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)).
@@ -36,6 +37,17 @@ Within the `examples` directory, you will also find additional complex kernels
TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities:
+- MLA Decoding Performance on H100
+<div style="display: flex; gap: 10px; justify-content: center;">
+  <div style="flex: 1;">
+    <img src="./examples/deepseek_mla/figures/bs64_float16.png" alt="mla decode performance bs64 on H100" width="100%" />
+  </div>
+  <div style="flex: 1;">
+    <img src="./examples/deepseek_mla/figures/bs128_float16.png" alt="mla decode performance bs128 on H100" width="100%" />
+  </div>
+</div>
- Flash Attention Performance on H100
<div align="center"> <img src="./images/mha_performance_h100.png" alt="operator performance on H100" width=80% />
@@ -87,7 +99,7 @@ We currently provide three ways to install **tile-lang** from source:
## Quick Start
-In this section, youll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
+In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
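To make the tiling idea concrete before the annotated tile-lang version, here is a toy tiled GEMM in plain Python. This is only a conceptual sketch of the block decomposition, not tile-lang code and not the document's example; on a GPU the tiles would live in shared memory and the reduction loop would be pipelined.

```python
# Toy tiled GEMM in plain Python: C = A @ B computed block by block,
# the same decomposition a tile-lang kernel expresses with shared-memory
# tiles and a pipelined K loop.

def tiled_gemm(A, B, tile=2):
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for i0 in range(0, M, tile):          # tile over rows of C
        for j0 in range(0, N, tile):      # tile over columns of C
            for k0 in range(0, K, tile):  # reduction tiles (the pipelined loop)
                for i in range(i0, min(i0 + tile, M)):
                    for j in range(j0, min(j0 + tile, N)):
                        for k in range(k0, min(k0 + tile, K)):
                            C[i][j] += A[i][k] * B[k][j]
    return C

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
assert tiled_gemm(A, B) == [[19, 22], [43, 50]]
```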
### GEMM Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.)
@@ -11,17 +11,17 @@ DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism know
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
<figure style="text-align: center">
-<a href="./bs64_float16.png">
-  <img src="./bs64_float16.png" alt="bs64_float16">
+<a href="./figures/bs64_float16.png">
+  <img src="./figures/bs64_float16.png" alt="bs64_float16">
</a>
-<figcaption>Figure 1:Performance under batch size=64</figcaption>
+<figcaption style="text-align: center;">Figure 1: Performance under batch size=64</figcaption>
</figure>
<figure style="text-align: center">
-<a href="./bs128_float16.png">
-  <img src="./bs128_float16.png" alt="bs128_float16">
+<a href="./figures/bs128_float16.png">
+  <img src="./figures/bs128_float16.png" alt="bs128_float16">
</a>
-<figcaption>Figure 2:Performance under batch size=128</figcaption>
+<figcaption style="text-align: center;">Figure 2: Performance under batch size=128</figcaption>
</figure>
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases and significantly outperforms both FlashInfer and Triton.
@@ -65,17 +65,17 @@ While the above process may seem complex, don't worry - TileLang will handle
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents a matrix multiplication operation, `transpose_B=True` indicates that matrix B is transposed, and `policy=FullCol` specifies that each warpgroup computes one full column (i.e., the result matrix is split along its vertical dimension). `T.copy` represents a buffer-to-buffer copy operation.
<figure style="text-align: center">
-<a href="./qk_layout.jpg">
-  <img src="./qk_layout.jpg" alt="QK Layout">
+<a href="./figures/qk_layout.jpg">
+  <img src="./figures/qk_layout.jpg" alt="QK Layout">
</a>
-<figcaption>Figure 3:Buffer shapes in Q @ K</figcaption>
+<figcaption style="text-align: center;">Figure 3: Buffer shapes in Q @ K</figcaption>
</figure>
<figure style="text-align: center">
-<a href="./qk_layout.jpg">
-  <img src="./pv_layout.jpg" alt="PV Layout">
+<a href="./figures/pv_layout.jpg">
+  <img src="./figures/pv_layout.jpg" alt="PV Layout">
</a>
-<figcaption>Figure 4:Buffer shapes in acc_s @ V</figcaption>
+<figcaption style="text-align: center;">Figure 4: Buffer shapes in acc_s @ V</figcaption>
</figure>
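The `policy=FullCol` column split described above can be sketched in plain Python. This is a toy model, not TileLang: `fullcol_split` and the tiny matrices are made up for illustration, and each "warpgroup" here is just a loop iteration owning one column slice of the result.

```python
# Toy model of a FullCol-style split for C = A @ B^T (transpose_B=True):
# B_T has shape [N, K], so C[m][n] = sum_k A[m][k] * B_T[n][k].

def matmul_bt(A, B_T):
    M, K = len(A), len(A[0])
    N = len(B_T)
    return [[sum(A[m][k] * B_T[n][k] for k in range(K)) for n in range(N)]
            for m in range(M)]

def fullcol_split(A, B_T, num_warpgroups):
    # Each "warpgroup" g computes columns [g*step, (g+1)*step) of C,
    # i.e. the result matrix is split along its vertical dimension.
    M, K = len(A), len(A[0])
    N = len(B_T)
    step = N // num_warpgroups
    C = [[0] * N for _ in range(M)]
    for g in range(num_warpgroups):
        for m in range(M):
            for n in range(g * step, (g + 1) * step):
                C[m][n] = sum(A[m][k] * B_T[n][k] for k in range(K))
    return C

A = [[1, 2], [3, 4]]
B_T = [[5, 6], [7, 8], [9, 10], [11, 12]]  # N=4, K=2
assert fullcol_split(A, B_T, 2) == matmul_bt(A, B_T)
```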
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
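As a rough illustration of the deduction (plain Python, not TileLang's actual inference pass; `infer_gemm_shape` is a made-up helper and the dimensions below are only example numbers):

```python
# Toy shape inference for a gemm: given the two operand shapes and the
# transpose flags, deduce the accumulator shape and check that the
# reduction (K) dimensions agree - the kind of fact a layout-inference
# pass derives from a T.gemm operator.

def infer_gemm_shape(a_shape, b_shape, transpose_A=False, transpose_B=False):
    # A is stored [K, M] when transpose_A, else [M, K]
    M, K_a = (a_shape[1], a_shape[0]) if transpose_A else a_shape
    # B is stored [N, K] when transpose_B, else [K, N]
    K_b, N = (b_shape[1], b_shape[0]) if transpose_B else b_shape
    if K_a != K_b:
        raise ValueError(f"reduction dims disagree: {K_a} vs {K_b}")
    return (M, N)

# Q @ K^T style: both operands [block, dim] with transpose_B=True
assert infer_gemm_shape((64, 576), (64, 576), transpose_B=True) == (64, 64)
# scores @ V style: [block_M, block_N] @ [block_N, dim_v]
assert infer_gemm_shape((64, 64), (64, 512)) == (64, 512)
```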
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -105,8 +105,7 @@ print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x16
-warp_layout = base_layout.repeat([block_rows, 1],
-                                 repeat_on_thread=True).replicate(block_cols)
+warp_layout = base_layout.repeat([block_rows, 1], repeat_on_thread=True).replicate(block_cols)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
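To build intuition for what repeating a base layout does, here is a toy model in plain Python. This is not TileLang's real `Fragment` implementation: `make_base` and `repeat_rows_on_thread` are invented names, and the mapping is a deliberately simple stand-in for a layout that maps `(row, col)` to `(thread_id, local_id)`.

```python
# Toy layout model: a layout is a dict from (row, col) to (thread, local).
# Repeating on threads tiles the base layout down the rows, giving each
# copy a fresh block of thread ids (so the thread count grows).

def make_base(rows, cols, threads_per_row):
    # simple base: threads laid out row-major, each thread owning a
    # contiguous chunk of columns
    chunk = cols // threads_per_row
    return {(i, j): (i * threads_per_row + j // chunk, j % chunk)
            for i in range(rows) for j in range(cols)}

def repeat_rows_on_thread(layout, rows, r):
    n_threads = max(t for t, _ in layout.values()) + 1
    out = {}
    for rep in range(r):
        for (i, j), (t, l) in layout.items():
            out[(rep * rows + i, j)] = (rep * n_threads + t, l)
    return out

base = make_base(4, 8, 4)                    # 4x8 tile, 16 threads, 2 elems/thread
warp = repeat_rows_on_thread(base, 4, 8)     # 32x8 tile, 128 threads
assert warp[(0, 0)] == base[(0, 0)]
assert warp[(4, 0)] == (16, 0)               # second copy starts at thread 16
```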
@@ -124,11 +124,23 @@ def plot_layout(layout: T.Layout,
thread_fontsize = min(font_size, font_size * (4 / len(thread_str)))
# Add thread ID text with adjusted font size
-ax.text(j + 0.5, i + 0.3, thread_str,
-        ha='center', va='center', color='black', fontsize=thread_fontsize)
+ax.text(
+    j + 0.5,
+    i + 0.3,
+    thread_str,
+    ha='center',
+    va='center',
+    color='black',
+    fontsize=thread_fontsize)
# Add local ID text with original font size
-ax.text(j + 0.5, i + 0.7, f"L{local_id}",
-        ha='center', va='center', color='black', fontsize=font_size)
+ax.text(
+    j + 0.5,
+    i + 0.7,
+    f"L{local_id}",
+    ha='center',
+    va='center',
+    color='black',
+    fontsize=font_size)
# Add row labels to the left side of the plot
for i in range(nrows):